Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions airflow/providers/google/cloud/operators/dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1790,6 +1790,7 @@ class DataprocInstantiateWorkflowTemplateOperator(GoogleCloudBaseOperator):
account from the list granting this role to the originating account (templated).
:param deferrable: Run operator in the deferrable mode.
:param polling_interval_seconds: Time (seconds) to wait between calls to check the run status.
:param cancel_on_kill: Flag indicating whether to cancel the workflow when on_kill is called.
"""

template_fields: Sequence[str] = ("template_id", "impersonation_chain", "request_id", "parameters")
Expand All @@ -1812,6 +1813,7 @@ def __init__(
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
polling_interval_seconds: int = 10,
cancel_on_kill: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -1830,6 +1832,8 @@ def __init__(
self.impersonation_chain = impersonation_chain
self.deferrable = deferrable
self.polling_interval_seconds = polling_interval_seconds
self.cancel_on_kill = cancel_on_kill
self.operation_name: str | None = None

def execute(self, context: Context):
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
Expand All @@ -1845,24 +1849,26 @@ def execute(self, context: Context):
timeout=self.timeout,
metadata=self.metadata,
)
self.workflow_id = operation.operation.name.split("/")[-1]
operation_name = operation.operation.name
self.operation_name = operation_name
workflow_id = operation_name.split("/")[-1]
project_id = self.project_id or hook.project_id
if project_id:
DataprocWorkflowLink.persist(
context=context,
operator=self,
workflow_id=self.workflow_id,
workflow_id=workflow_id,
region=self.region,
project_id=project_id,
)
self.log.info("Template instantiated. Workflow Id : %s", self.workflow_id)
self.log.info("Template instantiated. Workflow Id : %s", workflow_id)
if not self.deferrable:
hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=operation)
self.log.info("Workflow %s completed successfully", self.workflow_id)
self.log.info("Workflow %s completed successfully", workflow_id)
else:
self.defer(
trigger=DataprocWorkflowTrigger(
name=operation.operation.name,
name=operation_name,
project_id=self.project_id,
region=self.region,
gcp_conn_id=self.gcp_conn_id,
Expand All @@ -1884,6 +1890,11 @@ def execute_complete(self, context, event=None) -> None:

self.log.info("Workflow %s completed successfully", event["operation_name"])

def on_kill(self) -> None:
    """Cancel the instantiated workflow's operation if cancellation is enabled.

    Does nothing when ``cancel_on_kill`` is falsy or when no operation name
    has been stored on the operator (``execute`` sets ``operation_name``).
    """
    # Guard clause: both the flag and a recorded operation name are required.
    if not (self.cancel_on_kill and self.operation_name):
        return

    dataproc_hook = DataprocHook(
        gcp_conn_id=self.gcp_conn_id,
        impersonation_chain=self.impersonation_chain,
    )
    operations_client = dataproc_hook.get_operations_client(region=self.region)
    operations_client.cancel_operation(name=self.operation_name)


class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator):
"""Instantiate a WorkflowTemplate Inline on Google Cloud Dataproc.
Expand Down Expand Up @@ -1926,6 +1937,7 @@ class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator)
account from the list granting this role to the originating account (templated).
:param deferrable: Run operator in the deferrable mode.
:param polling_interval_seconds: Time (seconds) to wait between calls to check the run status.
:param cancel_on_kill: Flag indicating whether to cancel the workflow when on_kill is called.
"""

template_fields: Sequence[str] = ("template", "impersonation_chain")
Expand All @@ -1946,6 +1958,7 @@ def __init__(
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
polling_interval_seconds: int = 10,
cancel_on_kill: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -1963,6 +1976,8 @@ def __init__(
self.impersonation_chain = impersonation_chain
self.deferrable = deferrable
self.polling_interval_seconds = polling_interval_seconds
self.cancel_on_kill = cancel_on_kill
self.operation_name: str | None = None

def execute(self, context: Context):
self.log.info("Instantiating Inline Template")
Expand All @@ -1977,23 +1992,25 @@ def execute(self, context: Context):
timeout=self.timeout,
metadata=self.metadata,
)
self.workflow_id = operation.operation.name.split("/")[-1]
operation_name = operation.operation.name
self.operation_name = operation_name
workflow_id = operation_name.split("/")[-1]
if project_id:
DataprocWorkflowLink.persist(
context=context,
operator=self,
workflow_id=self.workflow_id,
workflow_id=workflow_id,
region=self.region,
project_id=project_id,
)
if not self.deferrable:
self.log.info("Template instantiated. Workflow Id : %s", self.workflow_id)
self.log.info("Template instantiated. Workflow Id : %s", workflow_id)
operation.result()
self.log.info("Workflow %s completed successfully", self.workflow_id)
self.log.info("Workflow %s completed successfully", workflow_id)
else:
self.defer(
trigger=DataprocWorkflowTrigger(
name=operation.operation.name,
name=operation_name,
project_id=self.project_id or hook.project_id,
region=self.region,
gcp_conn_id=self.gcp_conn_id,
Expand All @@ -2015,6 +2032,11 @@ def execute_complete(self, context, event=None) -> None:

self.log.info("Workflow %s completed successfully", event["operation_name"])

def on_kill(self) -> None:
    """Request cancellation of the inline workflow's operation.

    The request is only sent when ``cancel_on_kill`` is truthy and an
    operation name has been stored on the operator by ``execute``.
    """
    should_cancel = self.cancel_on_kill and self.operation_name
    if not should_cancel:
        return

    # Build a hook with the operator's own connection settings, then ask the
    # regional operations client to cancel the stored operation.
    hook = DataprocHook(
        gcp_conn_id=self.gcp_conn_id,
        impersonation_chain=self.impersonation_chain,
    )
    client = hook.get_operations_client(region=self.region)
    client.cancel_operation(name=self.operation_name)


class DataprocSubmitJobOperator(GoogleCloudBaseOperator):
"""Submit a job to a cluster.
Expand Down
64 changes: 63 additions & 1 deletion tests/providers/google/cloud/operators/test_dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1399,7 +1399,7 @@ def test_update_cluster_operator_extra_links(dag_maker, create_task_instance_of_
assert ti.task.get_extra_links(ti, DataprocClusterLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED


class TestDataprocWorkflowTemplateInstantiateOperator:
class TestDataprocInstantiateWorkflowTemplateOperator:
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute(self, mock_hook):
version = 6
Expand Down Expand Up @@ -1463,6 +1463,37 @@ def test_execute_call_defer_method(self, mock_trigger_hook, mock_hook):
assert isinstance(exc.value.trigger, DataprocWorkflowTrigger)
assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME

@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_on_kill(self, mock_hook):
    """Check that on_kill cancels the operation only when cancel_on_kill is enabled."""
    expected_operation_name = "operation_name"
    hook_instance = mock_hook.return_value
    hook_instance.instantiate_workflow_template.return_value.operation.name = expected_operation_name
    # The mocked cancel call reached through the hook's operations client.
    cancel_operation = hook_instance.get_operations_client.return_value.cancel_operation

    operator_kwargs = dict(
        task_id=TASK_ID,
        template_id=TEMPLATE_ID,
        region=GCP_REGION,
        project_id=GCP_PROJECT,
        version=2,
        parameters={},
        request_id=REQUEST_ID,
        retry=RETRY,
        timeout=TIMEOUT,
        metadata=METADATA,
        gcp_conn_id=GCP_CONN_ID,
        impersonation_chain=IMPERSONATION_CHAIN,
        cancel_on_kill=False,
    )
    op = DataprocInstantiateWorkflowTemplateOperator(**operator_kwargs)

    # Run the operator so that it records the operation name.
    op.execute(context=mock.MagicMock())

    # With the flag disabled, killing the task must not cancel anything.
    op.on_kill()
    cancel_operation.assert_not_called()

    # With the flag enabled, exactly one cancel request is expected.
    op.cancel_on_kill = True
    op.on_kill()
    cancel_operation.assert_called_once_with(name=expected_operation_name)


@pytest.mark.need_serialized_dag
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
Expand Down Expand Up @@ -1561,6 +1592,37 @@ def test_execute_call_defer_method(self, mock_trigger_hook, mock_hook):
assert isinstance(exc.value.trigger, DataprocWorkflowTrigger)
assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME

@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_on_kill(self, mock_hook):
    """Check that the inline operator's on_kill honours the cancel_on_kill flag."""
    expected_operation_name = "operation_name"
    hook_instance = mock_hook.return_value
    instantiated = hook_instance.instantiate_inline_workflow_template.return_value
    instantiated.operation.name = expected_operation_name
    # The mocked cancel call reached through the hook's operations client.
    cancel_operation = hook_instance.get_operations_client.return_value.cancel_operation

    operator_kwargs = dict(
        task_id=TASK_ID,
        template={},
        region=GCP_REGION,
        project_id=GCP_PROJECT,
        request_id=REQUEST_ID,
        retry=RETRY,
        timeout=TIMEOUT,
        metadata=METADATA,
        gcp_conn_id=GCP_CONN_ID,
        impersonation_chain=IMPERSONATION_CHAIN,
        cancel_on_kill=False,
    )
    op = DataprocInstantiateInlineWorkflowTemplateOperator(**operator_kwargs)

    # Run the operator so that it records the operation name.
    op.execute(context=mock.MagicMock())

    # Flag disabled: no cancellation request may be issued.
    op.on_kill()
    cancel_operation.assert_not_called()

    # Flag enabled: the recorded operation is cancelled exactly once.
    op.cancel_on_kill = True
    op.on_kill()
    cancel_operation.assert_called_once_with(name=expected_operation_name)


@pytest.mark.need_serialized_dag
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
Expand Down