From ae08ee9e315a4da3674c0493b4c89c5a2a2fa8da Mon Sep 17 00:00:00 2001 From: Heorhi Parkhomenka Date: Mon, 19 Sep 2022 14:18:03 +0200 Subject: [PATCH 1/5] Change dataprep system tests assets --- .../cloud/example_dags/example_dataprep.py | 79 ------ .../providers/google/cloud/hooks/dataprep.py | 82 +++++- .../providers/google/cloud/links/dataprep.py | 61 ++++ .../google/cloud/operators/dataprep.py | 201 +++++++++++++- .../google/cloud/sensors/dataprep.py | 51 ++++ airflow/providers/google/provider.yaml | 5 + .../operators/cloud/dataprep.rst | 77 ++++- .../google/cloud/hooks/test_dataprep.py | 262 +++++++++++++++++- .../google/cloud/operators/test_dataprep.py | 216 ++++++++++++++- .../google/cloud/sensors/test_dataprep.py | 44 +++ .../google/cloud/dataprep/__init__.py | 0 .../google/cloud/dataprep/example_dataprep.py | 175 ++++++++++++ 12 files changed, 1143 insertions(+), 110 deletions(-) delete mode 100644 airflow/providers/google/cloud/example_dags/example_dataprep.py create mode 100644 airflow/providers/google/cloud/links/dataprep.py create mode 100644 airflow/providers/google/cloud/sensors/dataprep.py create mode 100644 tests/providers/google/cloud/sensors/test_dataprep.py create mode 100644 tests/system/providers/google/cloud/dataprep/__init__.py create mode 100644 tests/system/providers/google/cloud/dataprep/example_dataprep.py diff --git a/airflow/providers/google/cloud/example_dags/example_dataprep.py b/airflow/providers/google/cloud/example_dags/example_dataprep.py deleted file mode 100644 index 6e295fac08397..0000000000000 --- a/airflow/providers/google/cloud/example_dags/example_dataprep.py +++ /dev/null @@ -1,79 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Example Airflow DAG that shows how to use Google Dataprep. -""" -from __future__ import annotations - -import os -from datetime import datetime - -from airflow import models -from airflow.providers.google.cloud.operators.dataprep import ( - DataprepGetJobGroupOperator, - DataprepGetJobsForJobGroupOperator, - DataprepRunJobGroupOperator, -) - -DATAPREP_JOB_ID = int(os.environ.get("DATAPREP_JOB_ID", 12345677)) -DATAPREP_JOB_RECIPE_ID = int(os.environ.get("DATAPREP_JOB_RECIPE_ID", 12345677)) -DATAPREP_BUCKET = os.environ.get("DATAPREP_BUCKET", "gs://INVALID BUCKET NAME/name@email.com") - -DATA = { - "wrangledDataset": {"id": DATAPREP_JOB_RECIPE_ID}, - "overrides": { - "execution": "dataflow", - "profiler": False, - "writesettings": [ - { - "path": DATAPREP_BUCKET, - "action": "create", - "format": "csv", - "compression": "none", - "header": False, - "asSingleFile": False, - } - ], - }, -} - - -with models.DAG( - "example_dataprep", - start_date=datetime(2021, 1, 1), # Override to match your needs - catchup=False, -) as dag: - # [START how_to_dataprep_run_job_group_operator] - run_job_group = DataprepRunJobGroupOperator(task_id="run_job_group", body_request=DATA) - # [END how_to_dataprep_run_job_group_operator] - - # [START how_to_dataprep_get_jobs_for_job_group_operator] - get_jobs_for_job_group = DataprepGetJobsForJobGroupOperator( - task_id="get_jobs_for_job_group", job_id=DATAPREP_JOB_ID - ) - # [END how_to_dataprep_get_jobs_for_job_group_operator] - - # [START how_to_dataprep_get_job_group_operator] - get_job_group = DataprepGetJobGroupOperator( - task_id="get_job_group", - job_group_id=DATAPREP_JOB_ID, - embed="", - include_deleted=False, - ) - # [END how_to_dataprep_get_job_group_operator] - - run_job_group >> [get_jobs_for_job_group, get_job_group] diff --git a/airflow/providers/google/cloud/hooks/dataprep.py b/airflow/providers/google/cloud/hooks/dataprep.py index 45261fed00bdf..c7cbc3b55157b 100644 --- a/airflow/providers/google/cloud/hooks/dataprep.py +++ b/airflow/providers/google/cloud/hooks/dataprep.py @@ -19,8 +19,9 @@ from __future__ import annotations import json -import os +from enum import Enum from typing import Any +from urllib.parse import urljoin import requests from requests import HTTPError @@ -43,6 +44,17 @@ def _get_field(extras: dict, field_name: str): return extras.get(prefixed_name) or None +class JobGroupStatuses(str, Enum): + """Types of job group run statuses.""" + + CREATED = "Created" + UNDEFINED = "undefined" + IN_PROGRESS = "InProgress" + COMPLETE = "Complete" + FAILED = "Failed" + CANCELED = "Canceled" + + class GoogleDataprepHook(BaseHook): """ Hook for connection with Dataprep API. @@ -82,7 +94,7 @@ def get_jobs_for_job_group(self, job_id: int) -> dict[str, Any]: :param job_id: The ID of the job that will be fetched """ endpoint_path = f"v4/jobGroups/{job_id}/jobs" - url: str = os.path.join(self._base_url, endpoint_path) + url: str = urljoin(self._base_url, endpoint_path) response = requests.get(url, headers=self._headers) self._raise_for_status(response) return response.json() @@ -99,7 +111,7 @@ def get_job_group(self, job_group_id: int, embed: str, include_deleted: bool) -> """ params: dict[str, Any] = {"embed": embed, "includeDeleted": include_deleted} endpoint_path = f"v4/jobGroups/{job_group_id}" - url: str = os.path.join(self._base_url, endpoint_path) + url: str = urljoin(self._base_url, endpoint_path) response = requests.get(url, headers=self._headers, params=params) self._raise_for_status(response) return response.json() @@ -115,11 +127,73 @@ def run_job_group(self, body_request: dict) -> dict[str, Any]: :param body_request: The identifier for the recipe you would like to run. """ endpoint_path = "v4/jobGroups" - url: str = os.path.join(self._base_url, endpoint_path) + url: str = urljoin(self._base_url, endpoint_path) + response = requests.post(url, headers=self._headers, data=json.dumps(body_request)) + self._raise_for_status(response) + return response.json() + + @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10)) + def copy_flow( + self, *, flow_id: int, name: str = "", description: str = "", copy_datasources: bool = False + ) -> dict: + """ + Create a copy of the provided flow id, as well as all contained recipes. + + :param flow_id: ID of the flow to be copied + :param name: Name for the copy of the flow + :param description: Description of the copy of the flow + :param copy_datasources: Bool value to define should copies of data inputs be made or not. + """ + endpoint_path = f"v4/flows/{flow_id}/copy" + url: str = urljoin(self._base_url, endpoint_path) + body_request = { + "name": name, + "description": description, + "copyDatasources": copy_datasources, + } + response = requests.post(url, headers=self._headers, data=json.dumps(body_request)) + self._raise_for_status(response) + return response.json() + + @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10)) + def delete_flow(self, *, flow_id: int) -> None: + """ + Delete the flow with the provided id. + + :param flow_id: ID of the flow to be copied + """ + endpoint_path = f"v4/flows/{flow_id}" + url: str = urljoin(self._base_url, endpoint_path) + response = requests.delete(url, headers=self._headers) + self._raise_for_status(response) + + @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10)) + def run_flow(self, *, flow_id: int, body_request: dict) -> dict: + """ + Runs the flow with the provided id copy of the provided flow id. + + :param flow_id: ID of the flow to be copied + :param body_request: Body of the POST request to be sent. + """ + endpoint = f"v4/flows/{flow_id}/run" + url: str = urljoin(self._base_url, endpoint) response = requests.post(url, headers=self._headers, data=json.dumps(body_request)) self._raise_for_status(response) return response.json() + @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10)) + def get_job_group_status(self, *, job_group_id: int) -> JobGroupStatuses: + """ + Check the status of the Dataprep task to be finished. + + :param job_group_id: ID of the job group to check + """ + endpoint = f"/v4/jobGroups/{job_group_id}/status" + url: str = urljoin(self._base_url, endpoint) + response = requests.get(url, headers=self._headers) + self._raise_for_status(response) + return response.json() + def _raise_for_status(self, response: requests.models.Response) -> None: try: response.raise_for_status() diff --git a/airflow/providers/google/cloud/links/dataprep.py b/airflow/providers/google/cloud/links/dataprep.py new file mode 100644 index 0000000000000..38d900f0b44fc --- /dev/null +++ b/airflow/providers/google/cloud/links/dataprep.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import TYPE_CHECKING + +from airflow.providers.google.cloud.links.base import BaseGoogleLink + +if TYPE_CHECKING: + from airflow.utils.context import Context + +BASE_LINK = "https://clouddataprep.com" +DATAPREP_FLOW_LINK = BASE_LINK + "/flows/{flow_id}?projectId={project_id}" +DATAPREP_JOB_GROUP_LINK = BASE_LINK + "/jobs/{job_group_id}?projectId={project_id}" + + +class DataprepFlowLink(BaseGoogleLink): + """Helper class for constructing Dataprep flow link.""" + + name = "Flow details page" + key = "dataprep_flow_page" + format_str = DATAPREP_FLOW_LINK + + @staticmethod + def persist(context: "Context", task_instance, project_id: str, flow_id: int): + task_instance.xcom_push( + context=context, + key=DataprepFlowLink.key, + value={"project_id": project_id, "flow_id": flow_id}, + ) + + +class DataprepJobGroupLink(BaseGoogleLink): + """Helper class for constructing Dataprep job group link.""" + + name = "Job group details page" + key = "dataprep_job_group_page" + format_str = DATAPREP_JOB_GROUP_LINK + + @staticmethod + def persist(context: "Context", task_instance, project_id: str, job_group_id: int): + task_instance.xcom_push( + context=context, + key=DataprepJobGroupLink.key, + value={ + "project_id": project_id, + "job_group_id": job_group_id, + }, + ) diff --git a/airflow/providers/google/cloud/operators/dataprep.py b/airflow/providers/google/cloud/operators/dataprep.py index ac62f0103281c..4a8cc28ac3ef4 100644 --- a/airflow/providers/google/cloud/operators/dataprep.py +++ b/airflow/providers/google/cloud/operators/dataprep.py @@ -18,10 +18,11 @@ """This module contains a Google Dataprep operator.""" from __future__ import annotations -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING, Optional, Sequence, Union from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.dataprep import GoogleDataprepHook +from airflow.providers.google.cloud.links.dataprep import DataprepFlowLink, DataprepJobGroupLink if TYPE_CHECKING: from airflow.utils.context import Context @@ -36,22 +37,28 @@ class DataprepGetJobsForJobGroupOperator(BaseOperator): For more information on how to use this operator, take a look at the guide: :ref:`howto/operator:DataprepGetJobsForJobGroupOperator` - :param job_id The ID of the job that will be requests + :param job_group_id The ID of the job group that will be requests """ - template_fields: Sequence[str] = ("job_id",) + template_fields: Sequence[str] = ("job_group_id",) - def __init__(self, *, dataprep_conn_id: str = "dataprep_default", job_id: int, **kwargs) -> None: + def __init__( + self, + *, + dataprep_conn_id: str = "dataprep_default", + job_group_id: Union[int, str], + **kwargs, + ) -> None: super().__init__(**kwargs) self.dataprep_conn_id = (dataprep_conn_id,) - self.job_id = job_id + self.job_group_id = job_group_id def execute(self, context: Context) -> dict: - self.log.info("Fetching data for job with id: %d ...", self.job_id) + self.log.info("Fetching data for job with id: %d ...", self.job_group_id) hook = GoogleDataprepHook( dataprep_conn_id="dataprep_default", ) - response = hook.get_jobs_for_job_group(job_id=self.job_id) + response = hook.get_jobs_for_job_group(job_id=int(self.job_group_id)) return response @@ -65,33 +72,49 @@ class DataprepGetJobGroupOperator(BaseOperator): For more information on how to use this operator, take a look at the guide: :ref:`howto/operator:DataprepGetJobGroupOperator` - :param job_group_id: The ID of the job that will be requests + :param job_group_id: The ID of the job group that will be requests :param embed: Comma-separated list of objects to pull in as part of the response :param include_deleted: if set to "true", will include deleted objects """ - template_fields: Sequence[str] = ("job_group_id", "embed") + template_fields: Sequence[str] = ( + "job_group_id", + "embed", + "project_id", + ) + operator_extra_links = (DataprepJobGroupLink(),) def __init__( self, *, dataprep_conn_id: str = "dataprep_default", - job_group_id: int, + project_id: Optional[str] = None, + job_group_id: Union[int, str], embed: str, include_deleted: bool, **kwargs, ) -> None: super().__init__(**kwargs) self.dataprep_conn_id: str = dataprep_conn_id + self.project_id = project_id self.job_group_id = job_group_id self.embed = embed self.include_deleted = include_deleted def execute(self, context: Context) -> dict: self.log.info("Fetching data for job with id: %d ...", self.job_group_id) + + if self.project_id: + DataprepJobGroupLink.persist( + context=context, + task_instance=self, + project_id=self.project_id, + job_group_id=int(self.job_group_id), + ) + hook = GoogleDataprepHook(dataprep_conn_id=self.dataprep_conn_id) response = hook.get_job_group( - job_group_id=self.job_group_id, + job_group_id=int(self.job_group_id), embed=self.embed, include_deleted=self.include_deleted, ) @@ -115,14 +138,166 @@ class DataprepRunJobGroupOperator(BaseOperator): """ template_fields: Sequence[str] = ("body_request",) + operator_extra_links = (DataprepJobGroupLink(),) - def __init__(self, *, dataprep_conn_id: str = "dataprep_default", body_request: dict, **kwargs) -> None: + def __init__( + self, + *, + project_id: Optional[str] = None, + dataprep_conn_id: str = "dataprep_default", + body_request: dict, + **kwargs, + ) -> None: super().__init__(**kwargs) - self.body_request = body_request + self.project_id = project_id self.dataprep_conn_id = dataprep_conn_id + self.body_request = body_request def execute(self, context: Context) -> dict: self.log.info("Creating a job...") hook = GoogleDataprepHook(dataprep_conn_id=self.dataprep_conn_id) response = hook.run_job_group(body_request=self.body_request) + + job_group_id = response.get('id') + if self.project_id and job_group_id: + DataprepJobGroupLink.persist( + context=context, + task_instance=self, + project_id=self.project_id, + job_group_id=int(job_group_id), + ) + + return response + + +class DataprepCopyFlowOperator(BaseOperator): + """ + Create a copy of the provided flow id, as well as all contained recipes. + + :param dataprep_conn_id: The Dataprep connection ID + :param flow_id: ID of the flow to be copied + :param name: Name for the copy of the flow + :param description: Description of the copy of the flow + :param copy_datasources: Bool value to define should the copy of data inputs be made or not. + """ + + template_fields: Sequence[str] = ( + 'flow_id', + 'name', + 'project_id', + 'description', + ) + operator_extra_links = (DataprepFlowLink(),) + + def __init__( + self, + *, + project_id: Optional[str] = None, + dataprep_conn_id: str = "dataprep_default", + flow_id: Union[int, str], + name: str = "", + description: str = "", + copy_datasources: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.dataprep_conn_id = dataprep_conn_id + self.flow_id = flow_id + self.name = name + self.description = description + self.copy_datasources = copy_datasources + + def execute(self, context: Context) -> dict: + self.log.info('Copying flow with id %d...', self.flow_id) + hook = GoogleDataprepHook(dataprep_conn_id=self.dataprep_conn_id) + response = hook.copy_flow( + flow_id=int(self.flow_id), + name=self.name, + description=self.description, + copy_datasources=self.copy_datasources, + ) + + copied_flow_id = response.get('id') + if self.project_id and copied_flow_id: + DataprepFlowLink.persist( + context=context, + task_instance=self, + project_id=self.project_id, + flow_id=int(copied_flow_id), + ) + return response + + +class DataprepDeleteFlowOperator(BaseOperator): + """ + Delete the flow with provided id. + + :param dataprep_conn_id: The Dataprep connection ID + :param flow_id: ID of the flow to be copied + """ + + template_fields: Sequence[str] = ('flow_id',) + + def __init__( + self, + *, + dataprep_conn_id: str = "dataprep_default", + flow_id: Union[int, str], + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.dataprep_conn_id = dataprep_conn_id + self.flow_id = flow_id + + def execute(self, context: Context) -> None: + self.log.info("Start delete operation of the flow with id: %d...", self.flow_id) + hook = GoogleDataprepHook(dataprep_conn_id=self.dataprep_conn_id) + hook.delete_flow(flow_id=int(self.flow_id)) + + +class DataprepRunFlowOperator(BaseOperator): + """ + Runs the flow with the provided id copy of the provided flow id. + + :param dataprep_conn_id: The Dataprep connection ID + :param flow_id: ID of the flow to be copied + :param body_request: Body of the POST request to be sent. + """ + + template_fields: Sequence[str] = ( + 'flow_id', + 'project_id', + ) + operator_extra_links = (DataprepJobGroupLink(),) + + def __init__( + self, + *, + project_id: Optional[str] = None, + flow_id: Union[int, str], + body_request: dict, + dataprep_conn_id: str = "dataprep_default", + **kwargs, + ): + super().__init__(**kwargs) + self.project_id = project_id + self.flow_id = flow_id + self.body_request = body_request + self.dataprep_conn_id = dataprep_conn_id + + def execute(self, context: Context) -> dict: + self.log.info("Running the flow with id: %d...", self.flow_id) + hooks = GoogleDataprepHook(dataprep_conn_id=self.dataprep_conn_id) + response = hooks.run_flow(flow_id=int(self.flow_id), body_request=self.body_request) + + if self.project_id: + job_group_id = response['data'][0]['id'] + DataprepJobGroupLink.persist( + context=context, + task_instance=self, + project_id=self.project_id, + job_group_id=int(job_group_id), + ) + return response diff --git a/airflow/providers/google/cloud/sensors/dataprep.py b/airflow/providers/google/cloud/sensors/dataprep.py new file mode 100644 index 0000000000000..63401367b0a4b --- /dev/null +++ b/airflow/providers/google/cloud/sensors/dataprep.py @@ -0,0 +1,51 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Dataprep Job sensor.""" +from typing import TYPE_CHECKING, Sequence, Union + +from airflow.providers.google.cloud.hooks.dataprep import GoogleDataprepHook, JobGroupStatuses +from airflow.sensors.base import BaseSensorOperator, PokeReturnValue + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class DataprepJobGroupIsFinishedSensor(BaseSensorOperator): + """ + Check the status of the Dataprep task to be finished. + + :param job_group_id: ID of the job group to check + """ + + template_fields: Sequence[str] = ('job_group_id',) + + def __init__( + self, + *, + job_group_id: Union[int, str], + dataprep_conn_id: str = "dataprep_default", + **kwargs, + ): + super().__init__(**kwargs) + self.job_group_id = job_group_id + self.dataprep_conn_id = dataprep_conn_id + + def poke(self, context: "Context") -> Union[bool, PokeReturnValue]: + hooks = GoogleDataprepHook(dataprep_conn_id=self.dataprep_conn_id) + status = hooks.get_job_group_status(job_group_id=int(self.job_group_id)) + return status != JobGroupStatuses.IN_PROGRESS diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml index 3533fba5f6b0a..268fc32bc9c57 100644 --- a/airflow/providers/google/provider.yaml +++ b/airflow/providers/google/provider.yaml @@ -609,6 +609,9 @@ sensors: - integration-name: Google Data Fusion python-modules: - airflow.providers.google.cloud.sensors.datafusion + - integration-name: Google Dataprep + python-modules: + - airflow.providers.google.cloud.sensors.dataprep - integration-name: Google Dataplex python-modules: - airflow.providers.google.cloud.sensors.dataplex @@ -985,6 +988,8 @@ extra-links: - airflow.providers.google.cloud.links.dataproc.DataprocListLink - airflow.providers.google.cloud.operators.dataproc_metastore.DataprocMetastoreDetailedLink - airflow.providers.google.cloud.operators.dataproc_metastore.DataprocMetastoreLink + - airflow.providers.google.cloud.links.dataprep.DataprepFlowLink + - airflow.providers.google.cloud.links.dataprep.DataprepJobGroupLink - airflow.providers.google.cloud.links.vertex_ai.VertexAIModelLink - airflow.providers.google.cloud.links.vertex_ai.VertexAIModelListLink - airflow.providers.google.cloud.links.vertex_ai.VertexAIModelExportLink diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataprep.rst b/docs/apache-airflow-providers-google/operators/cloud/dataprep.rst index 324a0b4875e93..282f16417f260 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/dataprep.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/dataprep.rst @@ -59,7 +59,7 @@ To get information about jobs within a Cloud Dataprep job use: Example usage: -.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_dataprep.py +.. exampleinclude:: /../../tests/system/providers/google/cloud/dataprep/example_dataprep.py :language: python :dedent: 4 :start-after: [START how_to_dataprep_run_job_group_operator] @@ -77,7 +77,7 @@ To get information about jobs within a Cloud Dataprep job use: Example usage: -.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_dataprep.py +.. exampleinclude:: /../../tests/system/providers/google/cloud/dataprep/example_dataprep.py :language: python :dedent: 4 :start-after: [START how_to_dataprep_get_jobs_for_job_group_operator] @@ -96,8 +96,79 @@ To get information about jobs within a Cloud Dataprep job use: Example usage: -.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_dataprep.py +.. exampleinclude:: /../../tests/system/providers/google/cloud/dataprep/example_dataprep.py :language: python :dedent: 4 :start-after: [START how_to_dataprep_get_job_group_operator] :end-before: [END how_to_dataprep_get_job_group_operator] + +Copy Flow +^^^^^^^^^^^^^ + +Operator task is to copy the flow. + +To get information about jobs within a Cloud Dataprep job use: +:class:`~airflow.providers.google.cloud.operators.dataprep.DataprepCopyFlowOperator` + +Example usage: + +.. exampleinclude:: /../../tests/system/providers/google/cloud/dataprep/example_dataprep.py + :language: python + :dedent: 4 + :start-after: [START how_to_dataprep_copy_flow_operator] + :end-before: [END how_to_dataprep_get_job_group_operator] + +Run Flow +^^^^^^^^^^^^^ + +Operator task is to run the flow. +A flow is a container for wrangling logic which contains +imported datasets, recipe, output objects, and References. + +To get information about jobs within a Cloud Dataprep job use: +:class:`~airflow.providers.google.cloud.operators.dataprep.DataprepRunFlowOperator` + +Example usage: + +.. exampleinclude:: /../../tests/system/providers/google/cloud/dataprep/example_dataprep.py + :language: python + :dedent: 4 + :start-after: [START how_to_dataprep_dataprep_run_flow_operator] + :end-before: [END how_to_dataprep_dataprep_run_flow_operator] + +Delete flow +^^^^^^^^^^^^^ + +Operator task is to delete the flow. +A flow is a container for wrangling logic which contains +imported datasets, recipe, output objects, and References. + +To get information about jobs within a Cloud Dataprep job use: +:class:`~airflow.providers.google.cloud.operators.dataprep.DataprepDeleteFlowOperator` + +Example usage: + +.. exampleinclude:: /../../tests/system/providers/google/cloud/dataprep/example_dataprep.py + :language: python + :dedent: 4 + :start-after: [START how_to_dataprep_delete_flow_operator] + :end-before: [END how_to_dataprep_delete_flow_operator] + + +Check if Job Group is finished +^^^^^^^^^^^^^ + +Sensor task is to tell the system when started job group is finished +no matter successfully or not. +A job group is a job that is executed from a specific node in a flow. + +To get information about jobs within a Cloud Dataprep job use: +:class:`~airflow.providers.google.cloud.sensors.dataprep.DataprepJobGroupIsFinishedSensor` + +Example usage: + +.. exampleinclude:: /../../tests/system/providers/google/cloud/dataprep/example_dataprep.py + :language: python + :dedent: 4 + :start-after: [START how_to_dataprep_job_group_finished_sensor] + :end-before: [END how_to_dataprep_job_group_finished_sensor] diff --git a/tests/providers/google/cloud/hooks/test_dataprep.py b/tests/providers/google/cloud/hooks/test_dataprep.py index b5950e871a23f..44d1f88d61033 100644 --- a/tests/providers/google/cloud/hooks/test_dataprep.py +++ b/tests/providers/google/cloud/hooks/test_dataprep.py @@ -19,7 +19,7 @@ import json import os -from unittest import mock +from unittest import mock, TestCase from unittest.mock import patch import pytest @@ -35,7 +35,7 @@ EXTRA = {"token": TOKEN} EMBED = "" INCLUDE_DELETED = False -DATA = json.dumps({"wrangledDataset": {"id": RECIPE_ID}}) +DATA = {"wrangledDataset": {"id": RECIPE_ID}} URL = "https://api.clouddataprep.com/v4/jobGroups" @@ -151,7 +151,6 @@ def test_get_job_group_raise_error_after_five_calls(self, mock_get_request): @patch("airflow.providers.google.cloud.hooks.dataprep.requests.post") def test_run_job_group_should_be_called_once_with_params(self, mock_get_request): - data = '"{\\"wrangledDataset\\": {\\"id\\": 1234567}}"' self.hook.run_job_group(body_request=DATA) mock_get_request.assert_called_once_with( f"{URL}", @@ -159,7 +158,7 @@ def test_run_job_group_should_be_called_once_with_params(self, mock_get_request) "Content-Type": "application/json", "Authorization": f"Bearer {TOKEN}", }, - data=data, + data=json.dumps(DATA), ) @patch( @@ -206,6 +205,60 @@ def test_run_job_group_raise_error_after_five_calls(self, mock_get_request): assert "HTTPError" in str(ctx.value) assert mock_get_request.call_count == 5 + @patch("airflow.providers.google.cloud.hooks.dataprep.requests.get") + def test_get_job_group_status_should_be_called_once_with_params(self, mock_get_request): + self.hook.get_job_group_status(job_group_id=JOB_ID) + mock_get_request.assert_called_once_with( + f"{URL}/{JOB_ID}/status", + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {TOKEN}", + }, + ) + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.get", + side_effect=[HTTPError(), mock.MagicMock()], + ) + def test_get_job_group_status_should_pass_after_retry(self, mock_get_request): + self.hook.get_job_group_status(job_group_id=JOB_ID) + assert mock_get_request.call_count == 2 + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.get", + side_effect=[mock.MagicMock(), HTTPError()], + ) + def test_get_job_group_status_retry_after_success(self, mock_get_request): + self.hook.run_job_group.retry.sleep = mock.Mock() + self.hook.get_job_group_status(job_group_id=JOB_ID) + assert mock_get_request.call_count == 1 + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.get", + side_effect=[ + HTTPError(), + HTTPError(), + HTTPError(), + HTTPError(), + mock.MagicMock(), + ], + ) + def test_get_job_group_status_four_errors(self, mock_get_request): + self.hook.run_job_group.retry.sleep = mock.Mock() + self.hook.get_job_group_status(job_group_id=JOB_ID) + assert mock_get_request.call_count == 5 + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.get", + side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(), HTTPError()], + ) + def test_get_job_group_status_five_calls(self, mock_get_request): + with pytest.raises(RetryError) as ctx: + self.hook.get_job_group_status.retry.sleep = mock.Mock() + self.hook.get_job_group_status(job_group_id=JOB_ID) + assert "HTTPError" in str(ctx.value) + assert mock_get_request.call_count == 5 + @pytest.mark.parametrize( "uri", [ @@ -218,3 +271,204 @@ def test_conn_extra_backcompat_prefix(self, uri): hook = GoogleDataprepHook("my_conn") assert hook._token == "abc" assert hook._base_url == "abc" + + +class TestGoogleDataprepFlowPathHooks(TestCase): + _url = "https://api.clouddataprep.com/v4/flows" + + def setUp(self) -> None: + self._flow_id = 1234567 + self._expected_copy_flow_hook_data = json.dumps( + { + "name": "", + "description": "", + "copyDatasources": False, + } + ) + self._expected_run_flow_hook_data = json.dumps({}) + with mock.patch("airflow.hooks.base.BaseHook.get_connection") as conn: + conn.return_value.extra_dejson = EXTRA + self.hook = GoogleDataprepHook(dataprep_conn_id="dataprep_default") + + @patch("airflow.providers.google.cloud.hooks.dataprep.requests.post") + def test_copy_flow_should_be_called_once_with_params(self, mock_get_request): + self.hook.copy_flow( + flow_id=self._flow_id, + ) + mock_get_request.assert_called_once_with( + f"{self._url}/{self._flow_id}/copy", + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {TOKEN}", + }, + data=self._expected_copy_flow_hook_data, + ) + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.post", + side_effect=[HTTPError(), mock.MagicMock()], + ) + def test_copy_flow_should_pass_after_retry(self, mock_get_request): + self.hook.copy_flow(flow_id=self._flow_id) + assert mock_get_request.call_count == 2 + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.post", + side_effect=[mock.MagicMock(), HTTPError()], + ) + def test_copy_flow_should_not_retry_after_success(self, mock_get_request): + self.hook.copy_flow.retry.sleep = mock.Mock() + self.hook.copy_flow(flow_id=self._flow_id) + assert mock_get_request.call_count == 1 + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.post", + side_effect=[ + HTTPError(), + HTTPError(), + HTTPError(), + HTTPError(), + mock.MagicMock(), + ], + ) + def test_copy_flow_should_retry_after_four_errors(self, mock_get_request): + self.hook.copy_flow.retry.sleep = mock.Mock() + self.hook.copy_flow(flow_id=self._flow_id) + assert mock_get_request.call_count == 5 + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.post", + side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(), HTTPError()], + ) + def test_copy_flow_raise_error_after_five_calls(self, mock_get_request): + with pytest.raises(RetryError) as ctx: + self.hook.copy_flow.retry.sleep = mock.Mock() + self.hook.copy_flow(flow_id=self._flow_id) + assert "HTTPError" in str(ctx.value) + assert mock_get_request.call_count == 5 + + @patch("airflow.providers.google.cloud.hooks.dataprep.requests.delete") + def test_delete_flow_should_be_called_once_with_params(self, mock_get_request): + self.hook.delete_flow( + flow_id=self._flow_id, + ) + mock_get_request.assert_called_once_with( + f"{self._url}/{self._flow_id}", + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {TOKEN}", + }, + ) + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.delete", + side_effect=[HTTPError(), mock.MagicMock()], + ) + def test_delete_flow_should_pass_after_retry(self, mock_get_request): + self.hook.delete_flow(flow_id=self._flow_id) + assert mock_get_request.call_count == 2 + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.delete", + side_effect=[mock.MagicMock(), HTTPError()], + ) + def test_delete_flow_should_not_retry_after_success(self, mock_get_request): + self.hook.delete_flow.retry.sleep = mock.Mock() + self.hook.delete_flow(flow_id=self._flow_id) + assert mock_get_request.call_count == 1 + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.delete", + side_effect=[ + HTTPError(), + HTTPError(), + HTTPError(), + HTTPError(), + mock.MagicMock(), + ], + ) + def test_delete_flow_should_retry_after_four_errors(self, mock_get_request): + self.hook.delete_flow.retry.sleep = mock.Mock() + self.hook.delete_flow(flow_id=self._flow_id) + assert mock_get_request.call_count == 5 + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.delete", + side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(), HTTPError()], + ) + def test_delete_flow_raise_error_after_five_calls(self, mock_get_request): + with pytest.raises(RetryError) as ctx: + self.hook.delete_flow.retry.sleep = mock.Mock() + self.hook.delete_flow(flow_id=self._flow_id) + assert "HTTPError" in str(ctx.value) + assert mock_get_request.call_count == 5 + + @patch("airflow.providers.google.cloud.hooks.dataprep.requests.post") + def test_run_flow_should_be_called_once_with_params(self, mock_get_request): + self.hook.run_flow( + flow_id=self._flow_id, + body_request={}, + ) + mock_get_request.assert_called_once_with( + f"{self._url}/{self._flow_id}/run", + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {TOKEN}", + }, + data=self._expected_run_flow_hook_data, + ) + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.post", + side_effect=[HTTPError(), mock.MagicMock()], + ) + def test_run_flow_should_pass_after_retry(self, mock_get_request): + self.hook.run_flow( + flow_id=self._flow_id, + body_request={}, + ) + assert mock_get_request.call_count == 2 + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.post", + side_effect=[mock.MagicMock(), HTTPError()], + ) + def test_run_flow_should_not_retry_after_success(self, mock_get_request): + self.hook.run_flow.retry.sleep = mock.Mock() + self.hook.run_flow( + flow_id=self._flow_id, + body_request={}, + ) + assert mock_get_request.call_count == 1 + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.post", + side_effect=[ + HTTPError(), + HTTPError(), + HTTPError(), + HTTPError(), + mock.MagicMock(), + ], + ) + def test_run_flow_should_retry_after_four_errors(self, mock_get_request): + self.hook.run_flow.retry.sleep = mock.Mock() + self.hook.run_flow( + flow_id=self._flow_id, + body_request={}, + ) + assert mock_get_request.call_count == 5 + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.post", + side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(), HTTPError()], + ) + def test_run_flow_raise_error_after_five_calls(self, mock_get_request): + with pytest.raises(RetryError) as ctx: + self.hook.run_flow.retry.sleep = mock.Mock() + self.hook.run_flow( + flow_id=self._flow_id, + body_request={}, + ) + assert "HTTPError" in str(ctx.value) + assert mock_get_request.call_count == 5 diff --git a/tests/providers/google/cloud/operators/test_dataprep.py b/tests/providers/google/cloud/operators/test_dataprep.py index 94e0591ad7bab..12e5507ac1414 100644 --- a/tests/providers/google/cloud/operators/test_dataprep.py +++ b/tests/providers/google/cloud/operators/test_dataprep.py @@ -16,17 +16,24 @@ # specific language governing permissions and limitations # under the License. from __future__ import annotations +from unittest import mock -from unittest import TestCase, mock +import pytest from airflow.providers.google.cloud.operators.dataprep import ( + DataprepCopyFlowOperator, + DataprepDeleteFlowOperator, DataprepGetJobGroupOperator, DataprepGetJobsForJobGroupOperator, + DataprepRunFlowOperator, DataprepRunJobGroupOperator, ) +GCP_PROJECT_ID = "test-project-id" DATAPREP_CONN_ID = "dataprep_default" JOB_ID = 143 +FLOW_ID = 128754 +NEW_FLOW_ID = 1312 TASK_ID = "dataprep_job" INCLUDE_DELETED = False EMBED = "" @@ -51,40 +58,235 @@ } -class TestDataprepGetJobsForJobGroupOperator(TestCase): +class TestDataprepGetJobsForJobGroupOperator: @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook") def test_execute(self, hook_mock): op = DataprepGetJobsForJobGroupOperator( - dataprep_conn_id=DATAPREP_CONN_ID, job_id=JOB_ID, task_id=TASK_ID + dataprep_conn_id=DATAPREP_CONN_ID, job_group_id=JOB_ID, task_id=TASK_ID ) op.execute(context={}) hook_mock.assert_called_once_with(dataprep_conn_id="dataprep_default") hook_mock.return_value.get_jobs_for_job_group.assert_called_once_with(job_id=JOB_ID) -class TestDataprepGetJobGroupOperator(TestCase): +class TestDataprepGetJobGroupOperator: @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook") def test_execute(self, hook_mock): op = DataprepGetJobGroupOperator( dataprep_conn_id=DATAPREP_CONN_ID, + project_id=None, job_group_id=JOB_ID, embed=EMBED, include_deleted=INCLUDE_DELETED, task_id=TASK_ID, ) - op.execute(context={}) + op.execute(context=mock.MagicMock()) hook_mock.assert_called_once_with(dataprep_conn_id="dataprep_default") hook_mock.return_value.get_job_group.assert_called_once_with( job_group_id=JOB_ID, embed=EMBED, include_deleted=INCLUDE_DELETED ) + @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook") + @mock.patch("airflow.providers.google.cloud.operators.dataprep.DataprepJobGroupLink") + @pytest.mark.parametrize( + 'provide_project_id, expected_call_count', + [ + (True, 1), + (False, 0), + ], + ) + def test_execute_with_project_id_will_persist_link_to_job_group( + self, + link_mock, + _, + provide_project_id, + expected_call_count, + ): + context = mock.MagicMock() + project_id = GCP_PROJECT_ID if provide_project_id else None + + op = DataprepGetJobGroupOperator( + task_id=TASK_ID, + project_id=project_id, + dataprep_conn_id=DATAPREP_CONN_ID, + job_group_id=JOB_ID, + embed=EMBED, + include_deleted=INCLUDE_DELETED, + ) + op.execute(context=context) + + assert link_mock.persist.call_count == expected_call_count + if provide_project_id: + link_mock.persist.assert_called_with( + context=context, + task_instance=op, + project_id=project_id, + job_group_id=JOB_ID, + ) -class TestDataprepRunJobGroupOperator(TestCase): + +class TestDataprepRunJobGroupOperator: @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook") def test_execute(self, hook_mock): op = DataprepRunJobGroupOperator( - dataprep_conn_id=DATAPREP_CONN_ID, body_request=DATA, task_id=TASK_ID + dataprep_conn_id=DATAPREP_CONN_ID, + body_request=DATA, + task_id=TASK_ID, ) op.execute(context=None) hook_mock.assert_called_once_with(dataprep_conn_id="dataprep_default") hook_mock.return_value.run_job_group.assert_called_once_with(body_request=DATA) + + +class TestDataprepCopyFlowOperatorTest: + @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook") + def test_execute_with_default_params(self, hook_mock): + op = DataprepCopyFlowOperator( + task_id=TASK_ID, + dataprep_conn_id=DATAPREP_CONN_ID, + flow_id=FLOW_ID, + ) + op.execute(context=mock.MagicMock()) + hook_mock.assert_called_once_with(dataprep_conn_id="dataprep_default") + hook_mock.return_value.copy_flow.assert_called_once_with( + flow_id=FLOW_ID, + name="", + description="", + copy_datasources=False, + ) + + @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook") + def test_execute_with_specified_params(self, hook_mock): + op = DataprepCopyFlowOperator( + task_id=TASK_ID, + dataprep_conn_id=DATAPREP_CONN_ID, + flow_id=FLOW_ID, + name="specified name", + description="specified description", + copy_datasources=True, + ) + op.execute(context=mock.MagicMock()) + hook_mock.assert_called_once_with(dataprep_conn_id="dataprep_default") + hook_mock.return_value.copy_flow.assert_called_once_with( + flow_id=FLOW_ID, name="specified name", description="specified description", copy_datasources=True + ) + + @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook") + def test_execute_with_templated_params(self, _, create_task_instance_of_operator): + dag_id = 'test_execute_with_templated_params' + ti = create_task_instance_of_operator( + DataprepCopyFlowOperator, + dag_id=dag_id, + project_id='{{ dag.dag_id }}', + task_id=TASK_ID, + flow_id='{{ dag.dag_id }}', + name='{{ dag.dag_id }}', + description='{{ dag.dag_id }}', + ) + ti.render_templates() + assert dag_id == ti.task.project_id + assert dag_id == ti.task.flow_id + assert dag_id == ti.task.name + assert dag_id == ti.task.description + + @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook") + @mock.patch("airflow.providers.google.cloud.operators.dataprep.DataprepFlowLink") + @pytest.mark.parametrize( + 'provide_project_id, expected_call_count', + [ + (True, 1), + (False, 0), + ], + ) + def test_execute_with_project_id_will_persist_link_to_flow( + self, + link_mock, + hook_mock, + provide_project_id, + expected_call_count, + ): + hook_mock.return_value.copy_flow.return_value = {'id': NEW_FLOW_ID} + context = mock.MagicMock() + project_id = GCP_PROJECT_ID if provide_project_id else None + + op = DataprepCopyFlowOperator( + task_id=TASK_ID, + project_id=project_id, + dataprep_conn_id=DATAPREP_CONN_ID, + flow_id=FLOW_ID, + name="specified name", + description="specified description", + copy_datasources=True, + ) + op.execute(context=context) + + assert link_mock.persist.call_count == expected_call_count + if provide_project_id: + link_mock.persist.assert_called_with( + context=context, + task_instance=op, + project_id=project_id, + flow_id=NEW_FLOW_ID, + ) + + +class TestDataprepDeleteFlowOperator: + @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook") + def test_execute(self, hook_mock): + op = DataprepDeleteFlowOperator( + task_id=TASK_ID, + dataprep_conn_id=DATAPREP_CONN_ID, + flow_id=FLOW_ID, + ) + op.execute(context=mock.MagicMock()) + hook_mock.assert_called_once_with(dataprep_conn_id="dataprep_default") + hook_mock.return_value.delete_flow.assert_called_once_with( + flow_id=FLOW_ID, + ) + + @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook") + def test_execute_with_template_params(self, _, create_task_instance_of_operator): + dag_id = 'test_execute_delete_flow_with_template' + ti = create_task_instance_of_operator( + DataprepDeleteFlowOperator, + dag_id=dag_id, + task_id=TASK_ID, + flow_id="{{ dag.dag_id }}", + ) + ti.render_templates() + assert dag_id == ti.task.flow_id + + +class TestDataprepRunFlowOperator: + @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook") + def test_execute(self, hook_mock): + op = DataprepRunFlowOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT_ID, + dataprep_conn_id=DATAPREP_CONN_ID, + flow_id=FLOW_ID, + body_request={}, + ) + op.execute(context=mock.MagicMock()) + hook_mock.assert_called_once_with(dataprep_conn_id="dataprep_default") + hook_mock.return_value.run_flow.assert_called_once_with( + flow_id=FLOW_ID, + body_request={}, + ) + + @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook") + def test_execute_with_template_params(self, _, create_task_instance_of_operator): + dag_id = 'test_execute_run_flow_with_template' + ti = create_task_instance_of_operator( + DataprepRunFlowOperator, + dag_id=dag_id, + task_id=TASK_ID, + project_id="{{ dag.dag_id }}", + flow_id="{{ dag.dag_id }}", + body_request={}, + ) + + ti.render_templates() + + assert dag_id == ti.task.project_id + assert dag_id == ti.task.flow_id diff --git a/tests/providers/google/cloud/sensors/test_dataprep.py b/tests/providers/google/cloud/sensors/test_dataprep.py new file mode 100644 index 0000000000000..c800f55ddf7e7 --- /dev/null +++ b/tests/providers/google/cloud/sensors/test_dataprep.py @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from unittest import mock + +from airflow.providers.google.cloud.hooks.dataprep import JobGroupStatuses +from airflow.providers.google.cloud.sensors.dataprep import DataprepJobGroupIsFinishedSensor + +JOB_GROUP_ID = 1312 + + +class TestDataprepJobGroupIsFinishedSensor: + @mock.patch("airflow.providers.google.cloud.sensors.dataprep.GoogleDataprepHook") + def test_passing_arguments_to_hook(self, hook_mock): + sensor = DataprepJobGroupIsFinishedSensor( + task_id='check_job_group_finished', + job_group_id=JOB_GROUP_ID, + ) + + hook_mock.return_value.get_job_group_status.return_value = JobGroupStatuses.COMPLETE + is_job_group_finished = sensor.poke(context=mock.MagicMock()) + + assert is_job_group_finished + + hook_mock.assert_called_once_with( + dataprep_conn_id='dataprep_default', + ) + hook_mock.return_value.get_job_group_status.assert_called_once_with( + job_group_id=JOB_GROUP_ID, + ) diff --git a/tests/system/providers/google/cloud/dataprep/__init__.py b/tests/system/providers/google/cloud/dataprep/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/system/providers/google/cloud/dataprep/example_dataprep.py b/tests/system/providers/google/cloud/dataprep/example_dataprep.py new file mode 100644 index 0000000000000..f8fe8ff16868d --- /dev/null +++ b/tests/system/providers/google/cloud/dataprep/example_dataprep.py @@ -0,0 +1,175 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG that shows how to use Google Dataprep. +""" +import os +from datetime import datetime + +from airflow import models +from airflow.providers.google.cloud.operators.dataprep import ( + DataprepCopyFlowOperator, + DataprepDeleteFlowOperator, + DataprepGetJobGroupOperator, + DataprepGetJobsForJobGroupOperator, + DataprepRunFlowOperator, + DataprepRunJobGroupOperator, +) +from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator +from airflow.providers.google.cloud.sensors.dataprep import DataprepJobGroupIsFinishedSensor +from airflow.utils.trigger_rule import TriggerRule + +ENV_ID = os.environ.get('SYSTEM_TESTS_ENV_ID') +DAG_ID = "example_dataprep" + +GCP_PROJECT_ID = os.environ.get('GCP_PROJECT_ID') +GCS_BUCKET_NAME = f"dataprep-bucket-heorhi-{DAG_ID}-{ENV_ID}" +GCS_BUCKET_PATH = f"gs://{GCS_BUCKET_NAME}/task_results/" + +FLOW_ID = os.environ.get('FLOW_ID', 1) +RECIPE_ID = os.environ.get('RECIPE_ID') +RECIPE_NAME = os.environ.get('RECIPE_NAME') +WRITE_SETTINGS = ( + { + "writesettings": [ + { + "path": GCS_BUCKET_PATH, + "action": "create", + "format": "csv", + } + ], + }, +) + +with models.DAG( + DAG_ID, + schedule_interval="@once", + start_date=datetime(2021, 1, 1), # Override to match your needs + catchup=False, + tags=['example', 'dataprep'], + render_template_as_native_obj=True, +) as dag: + # [START how_to_gcs_create_bucket_operator] + create_bucket_task = GCSCreateBucketOperator( + task_id="create_bucket", + bucket_name=GCS_BUCKET_NAME, + project_id=GCP_PROJECT_ID, + ) + # [END how_to_gcs_create_bucket_operator] + + # [START how_to_dataprep_run_job_group_operator] + run_job_group_task = DataprepRunJobGroupOperator( + task_id='run_job_group', + project_id=GCP_PROJECT_ID, + body_request={ + "wrangledDataset": {"id": RECIPE_ID}, + "overrides": WRITE_SETTINGS, + }, + ) + # [END how_to_dataprep_run_job_group_operator] + + # [START how_to_dataprep_copy_flow_operator] + copy_task = DataprepCopyFlowOperator( + task_id="copy_flow", + project_id=GCP_PROJECT_ID, + flow_id=FLOW_ID, + name=f'dataprep_example_flow_{DAG_ID}_{ENV_ID}', + ) + # [END how_to_dataprep_copy_flow_operator] + + # [START how_to_dataprep_dataprep_run_flow_operator] + run_flow_task = DataprepRunFlowOperator( + task_id="run_flow", + project_id=GCP_PROJECT_ID, + flow_id="{{ task_instance.xcom_pull('copy_flow')['id'] }}", + body_request={ + "overrides": { + RECIPE_NAME: WRITE_SETTINGS, + }, + }, + ) + # [END how_to_dataprep_dataprep_run_flow_operator] + + # [START how_to_dataprep_get_job_group_operator] + get_job_group_task = DataprepGetJobGroupOperator( + task_id='get_job_group', + project_id=GCP_PROJECT_ID, + job_group_id="{{ task_instance.xcom_pull('run_flow')['data'][0]['id'] }}", + embed="", + include_deleted=False, + ) + + # [START how_to_dataprep_get_jobs_for_job_group_operator] + get_jobs_for_job_group_task = DataprepGetJobsForJobGroupOperator( + task_id="get_jobs_for_job_group", + job_group_id="{{ task_instance.xcom_pull('run_flow')['data'][0]['id'] }}", + ) + # [END how_to_dataprep_get_jobs_for_job_group_operator] + + # [START how_to_dataprep_job_group_finished_sensor] + check_flow_status_sensor = DataprepJobGroupIsFinishedSensor( + task_id="check_flow_status", + job_group_id="{{ task_instance.xcom_pull('run_flow')['data'][0]['id'] }}", + ) + # [END how_to_dataprep_job_group_finished_sensor] + + # [START how_to_dataprep_job_group_finished_sensor] + check_job_group_status_sensor = DataprepJobGroupIsFinishedSensor( + task_id="check_job_group_status", + job_group_id="{{ task_instance.xcom_pull('run_job_group')['id'] }}", + ) + # [END how_to_dataprep_job_group_finished_sensor] + + # [START how_to_dataprep_delete_flow_operator] + delete_flow_task = DataprepDeleteFlowOperator( + task_id="delete_flow", + flow_id="{{ task_instance.xcom_pull('copy_flow')['id'] }}", + ) + # [END how_to_dataprep_delete_flow_operator] + + # [START gcs_delete_bucket_operator] + delete_bucket_task = GCSDeleteBucketOperator( + task_id="delete_bucket", + bucket_name=GCS_BUCKET_NAME, + trigger_rule=TriggerRule.ALL_DONE, + ) + # [END gcs_delete_bucket_operator] + + ( + # TEST SETUP + create_bucket_task + >> copy_task + # TEST BODY + >> [run_job_group_task, run_flow_task] + >> get_job_group_task + >> get_jobs_for_job_group_task + # TEST TEARDOWN + >> check_flow_status_sensor + >> [delete_flow_task, check_job_group_status_sensor] + >> delete_bucket_task + ) + + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) From 1b27da8b9909a31e9f12856972c91b7d6584c9ac Mon Sep 17 00:00:00 2001 From: Heorhi Parkhomenka Date: Tue, 20 Sep 2022 12:24:22 +0200 Subject: [PATCH 2/5] Change logic of base google class, refactor of system test, pre-commit changes --- airflow/providers/google/cloud/links/base.py | 2 +- .../providers/google/cloud/links/dataprep.py | 6 ++++-- .../google/cloud/operators/dataprep.py | 20 +++++++++---------- .../google/cloud/sensors/dataprep.py | 8 +++++--- .../google/cloud/operators/test_dataprep.py | 1 + .../google/cloud/sensors/test_dataprep.py | 2 ++ .../google/cloud/dataprep/__init__.py | 16 +++++++++++++++ .../google/cloud/dataprep/example_dataprep.py | 11 +++++----- 8 files changed, 44 insertions(+), 22 deletions(-) diff --git a/airflow/providers/google/cloud/links/base.py b/airflow/providers/google/cloud/links/base.py index 6539043a86bcd..755266758e8e8 100644 --- a/airflow/providers/google/cloud/links/base.py +++ b/airflow/providers/google/cloud/links/base.py @@ -45,6 +45,6 @@ def get_link( conf = XCom.get_value(key=self.key, ti_key=ti_key) if not conf: return "" - if self.format_str.startswith(BASE_LINK): + if self.format_str.startswith("http"): return self.format_str.format(**conf) return BASE_LINK + self.format_str.format(**conf) diff --git a/airflow/providers/google/cloud/links/dataprep.py b/airflow/providers/google/cloud/links/dataprep.py index 38d900f0b44fc..66caf1cfe8933 100644 --- a/airflow/providers/google/cloud/links/dataprep.py +++ b/airflow/providers/google/cloud/links/dataprep.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + from typing import TYPE_CHECKING from airflow.providers.google.cloud.links.base import BaseGoogleLink @@ -34,7 +36,7 @@ class DataprepFlowLink(BaseGoogleLink): format_str = DATAPREP_FLOW_LINK @staticmethod - def persist(context: "Context", task_instance, project_id: str, flow_id: int): + def persist(context: Context, task_instance, project_id: str, flow_id: int): task_instance.xcom_push( context=context, key=DataprepFlowLink.key, @@ -50,7 +52,7 @@ class DataprepJobGroupLink(BaseGoogleLink): format_str = DATAPREP_JOB_GROUP_LINK @staticmethod - def persist(context: "Context", task_instance, project_id: str, job_group_id: int): + def persist(context: Context, task_instance, project_id: str, job_group_id: int): task_instance.xcom_push( context=context, key=DataprepJobGroupLink.key, diff --git a/airflow/providers/google/cloud/operators/dataprep.py b/airflow/providers/google/cloud/operators/dataprep.py index 4a8cc28ac3ef4..41752aa40b7a2 100644 --- a/airflow/providers/google/cloud/operators/dataprep.py +++ b/airflow/providers/google/cloud/operators/dataprep.py @@ -18,7 +18,7 @@ """This module contains a Google Dataprep operator.""" from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Sequence, Union +from typing import TYPE_CHECKING, Sequence from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.dataprep import GoogleDataprepHook @@ -46,7 +46,7 @@ def __init__( self, *, dataprep_conn_id: str = "dataprep_default", - job_group_id: Union[int, str], + job_group_id: int | str, **kwargs, ) -> None: super().__init__(**kwargs) @@ -88,8 +88,8 @@ def __init__( self, *, dataprep_conn_id: str = "dataprep_default", - project_id: Optional[str] = None, - job_group_id: Union[int, str], + project_id: str | None = None, + job_group_id: int | str, embed: str, include_deleted: bool, **kwargs, @@ -143,7 +143,7 @@ class DataprepRunJobGroupOperator(BaseOperator): def __init__( self, *, - project_id: Optional[str] = None, + project_id: str | None = None, dataprep_conn_id: str = "dataprep_default", body_request: dict, **kwargs, @@ -192,9 +192,9 @@ class DataprepCopyFlowOperator(BaseOperator): def __init__( self, *, - project_id: Optional[str] = None, + project_id: str | None = None, dataprep_conn_id: str = "dataprep_default", - flow_id: Union[int, str], + flow_id: int | str, name: str = "", description: str = "", copy_datasources: bool = False, @@ -243,7 +243,7 @@ def __init__( self, *, dataprep_conn_id: str = "dataprep_default", - flow_id: Union[int, str], + flow_id: int | str, **kwargs, ) -> None: super().__init__(**kwargs) @@ -274,8 +274,8 @@ class DataprepRunFlowOperator(BaseOperator): def __init__( self, *, - project_id: Optional[str] = None, - flow_id: Union[int, str], + project_id: str | None = None, + flow_id: int | str, body_request: dict, dataprep_conn_id: str = "dataprep_default", **kwargs, diff --git a/airflow/providers/google/cloud/sensors/dataprep.py b/airflow/providers/google/cloud/sensors/dataprep.py index 63401367b0a4b..8a5a96da924a4 100644 --- a/airflow/providers/google/cloud/sensors/dataprep.py +++ b/airflow/providers/google/cloud/sensors/dataprep.py @@ -16,7 +16,9 @@ # specific language governing permissions and limitations # under the License. """This module contains a Dataprep Job sensor.""" -from typing import TYPE_CHECKING, Sequence, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from airflow.providers.google.cloud.hooks.dataprep import GoogleDataprepHook, JobGroupStatuses from airflow.sensors.base import BaseSensorOperator, PokeReturnValue @@ -37,7 +39,7 @@ class DataprepJobGroupIsFinishedSensor(BaseSensorOperator): def __init__( self, *, - job_group_id: Union[int, str], + job_group_id: int | str, dataprep_conn_id: str = "dataprep_default", **kwargs, ): @@ -45,7 +47,7 @@ def __init__( self.job_group_id = job_group_id self.dataprep_conn_id = dataprep_conn_id - def poke(self, context: "Context") -> Union[bool, PokeReturnValue]: + def poke(self, context: Context) -> bool | PokeReturnValue: hooks = GoogleDataprepHook(dataprep_conn_id=self.dataprep_conn_id) status = hooks.get_job_group_status(job_group_id=int(self.job_group_id)) return status != JobGroupStatuses.IN_PROGRESS diff --git a/tests/providers/google/cloud/operators/test_dataprep.py b/tests/providers/google/cloud/operators/test_dataprep.py index 12e5507ac1414..0ba88f1ed364b 100644 --- a/tests/providers/google/cloud/operators/test_dataprep.py +++ b/tests/providers/google/cloud/operators/test_dataprep.py @@ -16,6 +16,7 @@ # specific language governing permissions and limitations # under the License. from __future__ import annotations + from unittest import mock import pytest diff --git a/tests/providers/google/cloud/sensors/test_dataprep.py b/tests/providers/google/cloud/sensors/test_dataprep.py index c800f55ddf7e7..f5b26fab658c1 100644 --- a/tests/providers/google/cloud/sensors/test_dataprep.py +++ b/tests/providers/google/cloud/sensors/test_dataprep.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +from __future__ import annotations + from unittest import mock from airflow.providers.google.cloud.hooks.dataprep import JobGroupStatuses diff --git a/tests/system/providers/google/cloud/dataprep/__init__.py b/tests/system/providers/google/cloud/dataprep/__init__.py index e69de29bb2d1d..13a83393a9124 100644 --- a/tests/system/providers/google/cloud/dataprep/__init__.py +++ b/tests/system/providers/google/cloud/dataprep/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/system/providers/google/cloud/dataprep/example_dataprep.py b/tests/system/providers/google/cloud/dataprep/example_dataprep.py index f8fe8ff16868d..c1335ee690e41 100644 --- a/tests/system/providers/google/cloud/dataprep/example_dataprep.py +++ b/tests/system/providers/google/cloud/dataprep/example_dataprep.py @@ -17,6 +17,8 @@ """ Example Airflow DAG that shows how to use Google Dataprep. """ +from __future__ import annotations + import os from datetime import datetime @@ -40,7 +42,7 @@ GCS_BUCKET_NAME = f"dataprep-bucket-heorhi-{DAG_ID}-{ENV_ID}" GCS_BUCKET_PATH = f"gs://{GCS_BUCKET_NAME}/task_results/" -FLOW_ID = os.environ.get('FLOW_ID', 1) +FLOW_ID = os.environ.get('FLOW_ID', '') RECIPE_ID = os.environ.get('RECIPE_ID') RECIPE_NAME = os.environ.get('RECIPE_NAME') WRITE_SETTINGS = ( @@ -57,19 +59,17 @@ with models.DAG( DAG_ID, - schedule_interval="@once", + schedule="@once", start_date=datetime(2021, 1, 1), # Override to match your needs catchup=False, tags=['example', 'dataprep'], render_template_as_native_obj=True, ) as dag: - # [START how_to_gcs_create_bucket_operator] create_bucket_task = GCSCreateBucketOperator( task_id="create_bucket", bucket_name=GCS_BUCKET_NAME, project_id=GCP_PROJECT_ID, ) - # [END how_to_gcs_create_bucket_operator] # [START how_to_dataprep_run_job_group_operator] run_job_group_task = DataprepRunJobGroupOperator( @@ -140,14 +140,13 @@ flow_id="{{ task_instance.xcom_pull('copy_flow')['id'] }}", ) # [END how_to_dataprep_delete_flow_operator] + delete_flow_task.trigger_rule = TriggerRule.ALL_DONE - # [START gcs_delete_bucket_operator] delete_bucket_task = GCSDeleteBucketOperator( task_id="delete_bucket", bucket_name=GCS_BUCKET_NAME, trigger_rule=TriggerRule.ALL_DONE, ) - # [END gcs_delete_bucket_operator] ( # TEST SETUP From d57062e0c2018fb7a0802ed0a31d8b48b7d11976 Mon Sep 17 00:00:00 2001 From: Heorhi Parkhomenka Date: Fri, 23 Sep 2022 13:06:38 +0200 Subject: [PATCH 3/5] Docs fix --- .../operators/cloud/dataprep.rst | 10 +++++----- .../google/cloud/dataprep/example_dataprep.py | 1 + 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataprep.rst b/docs/apache-airflow-providers-google/operators/cloud/dataprep.rst index 282f16417f260..4957235604ffc 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/dataprep.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/dataprep.rst @@ -103,7 +103,7 @@ Example usage: :end-before: [END how_to_dataprep_get_job_group_operator] Copy Flow -^^^^^^^^^^^^^ +^^^^^^^^^ Operator task is to copy the flow. @@ -116,10 +116,10 @@ Example usage: :language: python :dedent: 4 :start-after: [START how_to_dataprep_copy_flow_operator] - :end-before: [END how_to_dataprep_get_job_group_operator] + :end-before: [END how_to_dataprep_copy_flow_operator] Run Flow -^^^^^^^^^^^^^ +^^^^^^^^ Operator task is to run the flow. A flow is a container for wrangling logic which contains @@ -137,7 +137,7 @@ Example usage: :end-before: [END how_to_dataprep_dataprep_run_flow_operator] Delete flow -^^^^^^^^^^^^^ +^^^^^^^^^^^ Operator task is to delete the flow. A flow is a container for wrangling logic which contains @@ -156,7 +156,7 @@ Example usage: Check if Job Group is finished -^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Sensor task is to tell the system when started job group is finished no matter successfully or not. diff --git a/tests/system/providers/google/cloud/dataprep/example_dataprep.py b/tests/system/providers/google/cloud/dataprep/example_dataprep.py index c1335ee690e41..5ad55ddb3e9b1 100644 --- a/tests/system/providers/google/cloud/dataprep/example_dataprep.py +++ b/tests/system/providers/google/cloud/dataprep/example_dataprep.py @@ -112,6 +112,7 @@ embed="", include_deleted=False, ) + # [END how_to_dataprep_get_job_group_operator] # [START how_to_dataprep_get_jobs_for_job_group_operator] get_jobs_for_job_group_task = DataprepGetJobsForJobGroupOperator( From b85b06a9cce6d77b310196c335f7a6404473f081 Mon Sep 17 00:00:00 2001 From: Heorhi Parkhomenka Date: Tue, 4 Oct 2022 15:01:08 +0200 Subject: [PATCH 4/5] Fix provider.yaml build issue --- airflow/providers/google/cloud/sensors/dataprep.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/providers/google/cloud/sensors/dataprep.py b/airflow/providers/google/cloud/sensors/dataprep.py index 8a5a96da924a4..4541a87989bce 100644 --- a/airflow/providers/google/cloud/sensors/dataprep.py +++ b/airflow/providers/google/cloud/sensors/dataprep.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Sequence from airflow.providers.google.cloud.hooks.dataprep import GoogleDataprepHook, JobGroupStatuses -from airflow.sensors.base import BaseSensorOperator, PokeReturnValue +from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: from airflow.utils.context import Context @@ -47,7 +47,7 @@ def __init__( self.job_group_id = job_group_id self.dataprep_conn_id = dataprep_conn_id - def poke(self, context: Context) -> bool | PokeReturnValue: + def poke(self, context: Context) -> bool: hooks = GoogleDataprepHook(dataprep_conn_id=self.dataprep_conn_id) status = hooks.get_job_group_status(job_group_id=int(self.job_group_id)) return status != JobGroupStatuses.IN_PROGRESS From c412529acc9623140d0643853ef3eca868c40ad7 Mon Sep 17 00:00:00 2001 From: Heorhi Parkhomenka Date: Tue, 25 Oct 2022 16:18:45 +0200 Subject: [PATCH 5/5] Reformat files with latest black and flake8 config --- .../google/cloud/operators/dataprep.py | 22 +++++++++---------- .../google/cloud/sensors/dataprep.py | 2 +- .../google/cloud/hooks/test_dataprep.py | 2 +- .../google/cloud/operators/test_dataprep.py | 20 ++++++++--------- .../google/cloud/sensors/test_dataprep.py | 4 ++-- .../google/cloud/dataprep/example_dataprep.py | 18 +++++++-------- 6 files changed, 34 insertions(+), 34 deletions(-) diff --git a/airflow/providers/google/cloud/operators/dataprep.py b/airflow/providers/google/cloud/operators/dataprep.py index 41752aa40b7a2..61340b07473f3 100644 --- a/airflow/providers/google/cloud/operators/dataprep.py +++ b/airflow/providers/google/cloud/operators/dataprep.py @@ -158,7 +158,7 @@ def execute(self, context: Context) -> dict: hook = GoogleDataprepHook(dataprep_conn_id=self.dataprep_conn_id) response = hook.run_job_group(body_request=self.body_request) - job_group_id = response.get('id') + job_group_id = response.get("id") if self.project_id and job_group_id: DataprepJobGroupLink.persist( context=context, @@ -182,10 +182,10 @@ class DataprepCopyFlowOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'flow_id', - 'name', - 'project_id', - 'description', + "flow_id", + "name", + "project_id", + "description", ) operator_extra_links = (DataprepFlowLink(),) @@ -209,7 +209,7 @@ def __init__( self.copy_datasources = copy_datasources def execute(self, context: Context) -> dict: - self.log.info('Copying flow with id %d...', self.flow_id) + self.log.info("Copying flow with id %d...", self.flow_id) hook = GoogleDataprepHook(dataprep_conn_id=self.dataprep_conn_id) response = hook.copy_flow( flow_id=int(self.flow_id), @@ -218,7 +218,7 @@ def execute(self, context: Context) -> dict: copy_datasources=self.copy_datasources, ) - copied_flow_id = response.get('id') + copied_flow_id = response.get("id") if self.project_id and copied_flow_id: DataprepFlowLink.persist( context=context, @@ -237,7 +237,7 @@ class DataprepDeleteFlowOperator(BaseOperator): :param flow_id: ID of the flow to be copied """ - template_fields: Sequence[str] = ('flow_id',) + template_fields: Sequence[str] = ("flow_id",) def __init__( self, @@ -266,8 +266,8 @@ class DataprepRunFlowOperator(BaseOperator): """ template_fields: Sequence[str] = ( - 'flow_id', - 'project_id', + "flow_id", + "project_id", ) operator_extra_links = (DataprepJobGroupLink(),) @@ -292,7 +292,7 @@ def execute(self, context: Context) -> dict: response = hooks.run_flow(flow_id=int(self.flow_id), body_request=self.body_request) if self.project_id: - job_group_id = response['data'][0]['id'] + job_group_id = response["data"][0]["id"] DataprepJobGroupLink.persist( context=context, task_instance=self, diff --git a/airflow/providers/google/cloud/sensors/dataprep.py b/airflow/providers/google/cloud/sensors/dataprep.py index 4541a87989bce..d30f6e18e872c 100644 --- a/airflow/providers/google/cloud/sensors/dataprep.py +++ b/airflow/providers/google/cloud/sensors/dataprep.py @@ -34,7 +34,7 @@ class DataprepJobGroupIsFinishedSensor(BaseSensorOperator): :param job_group_id: ID of the job group to check """ - template_fields: Sequence[str] = ('job_group_id',) + template_fields: Sequence[str] = ("job_group_id",) def __init__( self, diff --git a/tests/providers/google/cloud/hooks/test_dataprep.py b/tests/providers/google/cloud/hooks/test_dataprep.py index 44d1f88d61033..a13369f734b42 100644 --- a/tests/providers/google/cloud/hooks/test_dataprep.py +++ b/tests/providers/google/cloud/hooks/test_dataprep.py @@ -19,7 +19,7 @@ import json import os -from unittest import mock, TestCase +from unittest import TestCase, mock from unittest.mock import patch import pytest diff --git a/tests/providers/google/cloud/operators/test_dataprep.py b/tests/providers/google/cloud/operators/test_dataprep.py index 0ba88f1ed364b..08237d0d5193a 100644 --- a/tests/providers/google/cloud/operators/test_dataprep.py +++ b/tests/providers/google/cloud/operators/test_dataprep.py @@ -90,7 +90,7 @@ def test_execute(self, hook_mock): @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook") @mock.patch("airflow.providers.google.cloud.operators.dataprep.DataprepJobGroupLink") @pytest.mark.parametrize( - 'provide_project_id, expected_call_count', + "provide_project_id, expected_call_count", [ (True, 1), (False, 0), @@ -174,15 +174,15 @@ def test_execute_with_specified_params(self, hook_mock): @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook") def test_execute_with_templated_params(self, _, create_task_instance_of_operator): - dag_id = 'test_execute_with_templated_params' + dag_id = "test_execute_with_templated_params" ti = create_task_instance_of_operator( DataprepCopyFlowOperator, dag_id=dag_id, - project_id='{{ dag.dag_id }}', + project_id="{{ dag.dag_id }}", task_id=TASK_ID, - flow_id='{{ dag.dag_id }}', - name='{{ dag.dag_id }}', - description='{{ dag.dag_id }}', + flow_id="{{ dag.dag_id }}", + name="{{ dag.dag_id }}", + description="{{ dag.dag_id }}", ) ti.render_templates() assert dag_id == ti.task.project_id @@ -193,7 +193,7 @@ def test_execute_with_templated_params(self, _, create_task_instance_of_operator @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook") @mock.patch("airflow.providers.google.cloud.operators.dataprep.DataprepFlowLink") @pytest.mark.parametrize( - 'provide_project_id, expected_call_count', + "provide_project_id, expected_call_count", [ (True, 1), (False, 0), @@ -206,7 +206,7 @@ def test_execute_with_project_id_will_persist_link_to_flow( provide_project_id, expected_call_count, ): - hook_mock.return_value.copy_flow.return_value = {'id': NEW_FLOW_ID} + hook_mock.return_value.copy_flow.return_value = {"id": NEW_FLOW_ID} context = mock.MagicMock() project_id = GCP_PROJECT_ID if provide_project_id else None @@ -247,7 +247,7 @@ def test_execute(self, hook_mock): @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook") def test_execute_with_template_params(self, _, create_task_instance_of_operator): - dag_id = 'test_execute_delete_flow_with_template' + dag_id = "test_execute_delete_flow_with_template" ti = create_task_instance_of_operator( DataprepDeleteFlowOperator, dag_id=dag_id, @@ -277,7 +277,7 @@ def test_execute(self, hook_mock): @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook") def test_execute_with_template_params(self, _, create_task_instance_of_operator): - dag_id = 'test_execute_run_flow_with_template' + dag_id = "test_execute_run_flow_with_template" ti = create_task_instance_of_operator( DataprepRunFlowOperator, dag_id=dag_id, diff --git a/tests/providers/google/cloud/sensors/test_dataprep.py b/tests/providers/google/cloud/sensors/test_dataprep.py index f5b26fab658c1..7ea1816aa79d4 100644 --- a/tests/providers/google/cloud/sensors/test_dataprep.py +++ b/tests/providers/google/cloud/sensors/test_dataprep.py @@ -29,7 +29,7 @@ class TestDataprepJobGroupIsFinishedSensor: @mock.patch("airflow.providers.google.cloud.sensors.dataprep.GoogleDataprepHook") def test_passing_arguments_to_hook(self, hook_mock): sensor = DataprepJobGroupIsFinishedSensor( - task_id='check_job_group_finished', + task_id="check_job_group_finished", job_group_id=JOB_GROUP_ID, ) @@ -39,7 +39,7 @@ def test_passing_arguments_to_hook(self, hook_mock): assert is_job_group_finished hook_mock.assert_called_once_with( - dataprep_conn_id='dataprep_default', + dataprep_conn_id="dataprep_default", ) hook_mock.return_value.get_job_group_status.assert_called_once_with( job_group_id=JOB_GROUP_ID, diff --git a/tests/system/providers/google/cloud/dataprep/example_dataprep.py b/tests/system/providers/google/cloud/dataprep/example_dataprep.py index 5ad55ddb3e9b1..9f478a5f0be5b 100644 --- a/tests/system/providers/google/cloud/dataprep/example_dataprep.py +++ b/tests/system/providers/google/cloud/dataprep/example_dataprep.py @@ -35,16 +35,16 @@ from airflow.providers.google.cloud.sensors.dataprep import DataprepJobGroupIsFinishedSensor from airflow.utils.trigger_rule import TriggerRule -ENV_ID = os.environ.get('SYSTEM_TESTS_ENV_ID') +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") DAG_ID = "example_dataprep" -GCP_PROJECT_ID = os.environ.get('GCP_PROJECT_ID') +GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") GCS_BUCKET_NAME = f"dataprep-bucket-heorhi-{DAG_ID}-{ENV_ID}" GCS_BUCKET_PATH = f"gs://{GCS_BUCKET_NAME}/task_results/" -FLOW_ID = os.environ.get('FLOW_ID', '') -RECIPE_ID = os.environ.get('RECIPE_ID') -RECIPE_NAME = os.environ.get('RECIPE_NAME') +FLOW_ID = os.environ.get("FLOW_ID", "") +RECIPE_ID = os.environ.get("RECIPE_ID") +RECIPE_NAME = os.environ.get("RECIPE_NAME") WRITE_SETTINGS = ( { "writesettings": [ @@ -62,7 +62,7 @@ schedule="@once", start_date=datetime(2021, 1, 1), # Override to match your needs catchup=False, - tags=['example', 'dataprep'], + tags=["example", "dataprep"], render_template_as_native_obj=True, ) as dag: create_bucket_task = GCSCreateBucketOperator( @@ -73,7 +73,7 @@ # [START how_to_dataprep_run_job_group_operator] run_job_group_task = DataprepRunJobGroupOperator( - task_id='run_job_group', + task_id="run_job_group", project_id=GCP_PROJECT_ID, body_request={ "wrangledDataset": {"id": RECIPE_ID}, @@ -87,7 +87,7 @@ task_id="copy_flow", project_id=GCP_PROJECT_ID, flow_id=FLOW_ID, - name=f'dataprep_example_flow_{DAG_ID}_{ENV_ID}', + name=f"dataprep_example_flow_{DAG_ID}_{ENV_ID}", ) # [END how_to_dataprep_copy_flow_operator] @@ -106,7 +106,7 @@ # [START how_to_dataprep_get_job_group_operator] get_job_group_task = DataprepGetJobGroupOperator( - task_id='get_job_group', + task_id="get_job_group", project_id=GCP_PROJECT_ID, job_group_id="{{ task_instance.xcom_pull('run_flow')['data'][0]['id'] }}", embed="",