From b0d8f86eead12d07011e6e3bbfde3a03a18024a7 Mon Sep 17 00:00:00 2001 From: Beata Kossakowska Date: Mon, 3 Oct 2022 06:55:58 +0000 Subject: [PATCH] Cloud ML Engine operators assets (AIP-47) --- .../providers/google/cloud/links/mlengine.py | 140 ++++++++++++++++++ .../google/cloud/operators/mlengine.py | 138 ++++++++++++----- airflow/providers/google/provider.yaml | 5 + .../google/cloud/operators/test_mlengine.py | 82 ++-------- 4 files changed, 256 insertions(+), 109 deletions(-) create mode 100644 airflow/providers/google/cloud/links/mlengine.py diff --git a/airflow/providers/google/cloud/links/mlengine.py b/airflow/providers/google/cloud/links/mlengine.py new file mode 100644 index 0000000000000..bbfe0cc5385ce --- /dev/null +++ b/airflow/providers/google/cloud/links/mlengine.py @@ -0,0 +1,140 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google ML Engine links.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.providers.google.cloud.links.base import BaseGoogleLink + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +MLENGINE_BASE_LINK = "https://console.cloud.google.com/ai-platform" +MLENGINE_MODEL_DETAILS_LINK = MLENGINE_BASE_LINK + "/models/{model_id}/versions?project={project_id}" +MLENGINE_MODEL_VERSION_DETAILS_LINK = ( + MLENGINE_BASE_LINK + "/models/{model_id}/versions/{version_id}/performance?project={project_id}" +) +MLENGINE_MODELS_LIST_LINK = MLENGINE_BASE_LINK + "/models/?project={project_id}" +MLENGINE_JOB_DETAILS_LINK = MLENGINE_BASE_LINK + "/jobs/{job_id}?project={project_id}" +MLENGINE_JOBS_LIST_LINK = MLENGINE_BASE_LINK + "/jobs?project={project_id}" + + +class MLEngineModelLink(BaseGoogleLink): + """Helper class for constructing ML Engine link""" + + name = "MLEngine Model" + key = "ml_engine_model" + format_str = MLENGINE_MODEL_DETAILS_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + model_id: str, + project_id: str, + ): + task_instance.xcom_push( + context, + key=MLEngineModelLink.key, + value={"model_id": model_id, "project_id": project_id}, + ) + + +class MLEngineModelsListLink(BaseGoogleLink): + """Helper class for constructing ML Engine link""" + + name = "MLEngine Models List" + key = "ml_engine_models_list" + format_str = MLENGINE_MODELS_LIST_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + project_id: str, + ): + task_instance.xcom_push( + context, + key=MLEngineModelsListLink.key, + value={"project_id": project_id}, + ) + + +class MLEngineJobDetailsLink(BaseGoogleLink): + """Helper class for constructing ML Engine link""" + + name = "MLEngine Job Details" + key = "ml_engine_job_details" + format_str = MLENGINE_JOB_DETAILS_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + job_id: str, + project_id: str, + ): + task_instance.xcom_push( + context, + key=MLEngineJobDetailsLink.key, + value={"job_id": job_id, "project_id": project_id}, + ) + + +class MLEngineModelVersionDetailsLink(BaseGoogleLink): + """Helper class for constructing ML Engine link""" + + name = "MLEngine Version Details" + key = "ml_engine_version_details" + format_str = MLENGINE_MODEL_VERSION_DETAILS_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + model_id: str, + project_id: str, + version_id: str, + ): + task_instance.xcom_push( + context, + key=MLEngineModelVersionDetailsLink.key, + value={"model_id": model_id, "project_id": project_id, "version_id": version_id}, + ) + + +class MLEngineJobSListLink(BaseGoogleLink): + """Helper class for constructing ML Engine link""" + + name = "MLEngine Jobs List" + key = "ml_engine_jobs_list" + format_str = MLENGINE_JOBS_LIST_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + project_id: str, + ): + task_instance.xcom_push( + context, + key=MLEngineJobSListLink.key, + value={"project_id": project_id}, + ) diff --git a/airflow/providers/google/cloud/operators/mlengine.py b/airflow/providers/google/cloud/operators/mlengine.py index f4acb7ae9b22a..7c258728bada6 100644 --- a/airflow/providers/google/cloud/operators/mlengine.py +++ b/airflow/providers/google/cloud/operators/mlengine.py @@ -16,20 +16,26 @@ # specific language governing permissions and limitations # under the License. """This module contains Google Cloud MLEngine operators.""" + from __future__ import annotations -import datetime import logging import re import warnings from typing import TYPE_CHECKING, Any, Sequence from airflow.exceptions import AirflowException -from airflow.models import BaseOperator, BaseOperatorLink, XCom +from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.mlengine import MLEngineHook +from airflow.providers.google.cloud.links.mlengine import ( + MLEngineJobDetailsLink, + MLEngineJobSListLink, + MLEngineModelLink, + MLEngineModelsListLink, + MLEngineModelVersionDetailsLink, +) if TYPE_CHECKING: - from airflow.models.taskinstance import TaskInstanceKey from airflow.utils.context import Context @@ -396,6 +402,7 @@ class MLEngineCreateModelOperator(BaseOperator): '_model', '_impersonation_chain', ) + operator_extra_links = (MLEngineModelLink(),) def __init__( self, @@ -420,6 +427,16 @@ def execute(self, context: Context): delegate_to=self._delegate_to, impersonation_chain=self._impersonation_chain, ) + + project_id = self._project_id or hook.project_id + if project_id: + MLEngineModelLink.persist( + context=context, + task_instance=self, + project_id=project_id, + model_id=self._model['name'], + ) + return hook.create_model(project_id=self._project_id, model=self._model) @@ -456,6 +473,7 @@ class MLEngineGetModelOperator(BaseOperator): '_model_name', '_impersonation_chain', ) + operator_extra_links = (MLEngineModelLink(),) def __init__( self, @@ -480,6 +498,15 @@ def execute(self, context: Context): delegate_to=self._delegate_to, impersonation_chain=self._impersonation_chain, ) + project_id = self._project_id or hook.project_id + if project_id: + MLEngineModelLink.persist( + context=context, + task_instance=self, + project_id=project_id, + model_id=self._model_name, + ) + return hook.get_model(project_id=self._project_id, model_name=self._model_name) @@ -519,6 +546,7 @@ class MLEngineDeleteModelOperator(BaseOperator): '_model_name', '_impersonation_chain', ) + operator_extra_links = (MLEngineModelsListLink(),) def __init__( self, @@ -546,6 +574,14 @@ def execute(self, context: Context): impersonation_chain=self._impersonation_chain, ) + project_id = self._project_id or hook.project_id + if project_id: + MLEngineModelsListLink.persist( + context=context, + task_instance=self, + project_id=project_id, + ) + return hook.delete_model( project_id=self._project_id, model_name=self._model_name, delete_contents=self._delete_contents ) @@ -711,6 +747,7 @@ class MLEngineCreateVersionOperator(BaseOperator): '_version', '_impersonation_chain', ) + operator_extra_links = (MLEngineModelVersionDetailsLink(),) def __init__( self, @@ -747,6 +784,16 @@ def execute(self, context: Context): impersonation_chain=self._impersonation_chain, ) + project_id = self._project_id or hook.project_id + if project_id: + MLEngineModelVersionDetailsLink.persist( + context=context, + task_instance=self, + project_id=project_id, + model_id=self._model_name, + version_id=self._version['name'], + ) + return hook.create_version( project_id=self._project_id, model_name=self._model_name, version_spec=self._version ) @@ -788,6 +835,7 @@ class MLEngineSetDefaultVersionOperator(BaseOperator): '_version_name', '_impersonation_chain', ) + operator_extra_links = (MLEngineModelVersionDetailsLink(),) def __init__( self, @@ -824,6 +872,16 @@ def execute(self, context: Context): impersonation_chain=self._impersonation_chain, ) + project_id = self._project_id or hook.project_id + if project_id: + MLEngineModelVersionDetailsLink.persist( + context=context, + task_instance=self, + project_id=project_id, + model_id=self._model_name, + version_id=self._version_name, + ) + return hook.set_default_version( project_id=self._project_id, model_name=self._model_name, version_name=self._version_name ) @@ -863,6 +921,7 @@ class MLEngineListVersionsOperator(BaseOperator): '_model_name', '_impersonation_chain', ) + operator_extra_links = (MLEngineModelLink(),) def __init__( self, @@ -894,6 +953,15 @@ def execute(self, context: Context): impersonation_chain=self._impersonation_chain, ) + project_id = self._project_id or hook.project_id + if project_id: + MLEngineModelLink.persist( + context=context, + task_instance=self, + project_id=project_id, + model_id=self._model_name, + ) + return hook.list_versions( project_id=self._project_id, model_name=self._model_name, @@ -936,6 +1004,7 @@ class MLEngineDeleteVersionOperator(BaseOperator): '_version_name', '_impersonation_chain', ) + operator_extra_links = (MLEngineModelLink(),) def __init__( self, @@ -972,40 +1041,20 @@ def execute(self, context: Context): impersonation_chain=self._impersonation_chain, ) + project_id = self._project_id or hook.project_id + if project_id: + MLEngineModelLink.persist( + context=context, + task_instance=self, + project_id=project_id, + model_id=self._model_name, + ) + return hook.delete_version( project_id=self._project_id, model_name=self._model_name, version_name=self._version_name ) -class AIPlatformConsoleLink(BaseOperatorLink): - """Helper class for constructing AI Platform Console link.""" - - name = "AI Platform Console" - - def get_link( - self, - operator, - dttm: datetime.datetime | None = None, - ti_key: TaskInstanceKey | None = None, - ) -> str: - if ti_key is not None: - gcp_metadata_dict = XCom.get_value(key="gcp_metadata", ti_key=ti_key) - else: - assert dttm is not None - gcp_metadata_dict = XCom.get_one( - key="gcp_metadata", - dag_id=operator.dag.dag_id, - task_id=operator.task_id, - execution_date=dttm, - ) - if not gcp_metadata_dict: - return '' - job_id = gcp_metadata_dict['job_id'] - project_id = gcp_metadata_dict['project_id'] - console_link = f"https://console.cloud.google.com/ai-platform/jobs/{job_id}?project={project_id}" - return console_link - - class MLEngineStartTrainingJobOperator(BaseOperator): """ Operator for launching a MLEngine training job. @@ -1087,8 +1136,7 @@ class MLEngineStartTrainingJobOperator(BaseOperator): '_hyperparameters', '_impersonation_chain', ) - - operator_extra_links = (AIPlatformConsoleLink(),) + operator_extra_links = (MLEngineJobDetailsLink(),) def __init__( self, @@ -1238,11 +1286,14 @@ def check_existing_job(existing_job): self.log.error('MLEngine training job failed: %s', str(finished_training_job)) raise RuntimeError(finished_training_job['errorMessage']) - gcp_metadata = { - "job_id": job_id, - "project_id": self._project_id, - } - context['task_instance'].xcom_push("gcp_metadata", gcp_metadata) + project_id = self._project_id or hook.project_id + if project_id: + MLEngineJobDetailsLink.persist( + context=context, + task_instance=self, + project_id=project_id, + job_id=job_id, + ) class MLEngineTrainingCancelJobOperator(BaseOperator): @@ -1273,6 +1324,7 @@ class MLEngineTrainingCancelJobOperator(BaseOperator): '_job_id', '_impersonation_chain', ) + operator_extra_links = (MLEngineJobSListLink(),) def __init__( self, @@ -1302,4 +1354,12 @@ def execute(self, context: Context): impersonation_chain=self._impersonation_chain, ) + project_id = self._project_id or hook.project_id + if project_id: + MLEngineJobSListLink.persist( + context=context, + task_instance=self, + project_id=project_id, + ) + hook.cancel_job(project_id=self._project_id, job_id=_normalize_mlengine_job_id(self._job_id)) diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml index 4be41370a5787..970827ea2abf2 100644 --- a/airflow/providers/google/provider.yaml +++ b/airflow/providers/google/provider.yaml @@ -1032,6 +1032,11 @@ extra-links: - airflow.providers.google.cloud.links.data_loss_prevention.CloudDLPInfoTypesListLink - airflow.providers.google.cloud.links.data_loss_prevention.CloudDLPInfoTypeDetailsLink - airflow.providers.google.cloud.links.data_loss_prevention.CloudDLPPossibleInfoTypesListLink + - airflow.providers.google.cloud.links.mlengine.MLEngineModelLink + - airflow.providers.google.cloud.links.mlengine.MLEngineModelsListLink + - airflow.providers.google.cloud.links.mlengine.MLEngineJobDetailsLink + - airflow.providers.google.cloud.links.mlengine.MLEngineJobSListLink + - airflow.providers.google.cloud.links.mlengine.MLEngineModelVersionDetailsLink - airflow.providers.google.common.links.storage.StorageLink - airflow.providers.google.common.links.storage.FileDetailsLink diff --git a/tests/providers/google/cloud/operators/test_mlengine.py b/tests/providers/google/cloud/operators/test_mlengine.py index c54969cbae907..f9b3c34301e55 100644 --- a/tests/providers/google/cloud/operators/test_mlengine.py +++ b/tests/providers/google/cloud/operators/test_mlengine.py @@ -26,8 +26,7 @@ from airflow.exceptions import AirflowException from airflow.models.dag import DAG -from airflow.providers.google.cloud.operators.mlengine import ( - AIPlatformConsoleLink, +from airflow.providers.google.cloud.operators.mlengine import ( # AIPlatformConsoleLink, MLEngineCreateModelOperator, MLEngineCreateVersionOperator, MLEngineDeleteModelOperator, @@ -41,7 +40,6 @@ MLEngineStartTrainingJobOperator, MLEngineTrainingCancelJobOperator, ) -from airflow.serialization.serialized_objects import SerializedDAG from airflow.utils import timezone DEFAULT_DATE = timezone.datetime(2017, 6, 6) @@ -565,62 +563,6 @@ def test_failed_job_error(self, mock_hook): ) assert 'A failure message' == str(ctx.value) - @patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook') - def test_console_extra_link(self, mock_hook, create_task_instance_of_operator): - ti = create_task_instance_of_operator( - MLEngineStartTrainingJobOperator, - dag_id="test_console_extra_link", - execution_date=DEFAULT_DATE, - **self.TRAINING_DEFAULT_ARGS, - ) - - job_id = self.TRAINING_DEFAULT_ARGS['job_id'] - project_id = self.TRAINING_DEFAULT_ARGS['project_id'] - gcp_metadata = { - "job_id": job_id, - "project_id": project_id, - } - ti.xcom_push(key='gcp_metadata', value=gcp_metadata) - - assert ( - f"https://console.cloud.google.com/ai-platform/jobs/{job_id}?project={project_id}" - == ti.task.get_extra_links(ti, AIPlatformConsoleLink.name) - ) - - @pytest.mark.need_serialized_dag - def test_console_extra_link_serialized_field(self, dag_maker, create_task_instance_of_operator): - ti = create_task_instance_of_operator( - MLEngineStartTrainingJobOperator, - dag_id="test_console_extra_link_serialized_field", - execution_date=DEFAULT_DATE, - **self.TRAINING_DEFAULT_ARGS, - ) - serialized_dag = dag_maker.get_serialized_data() - dag = SerializedDAG.from_dict(serialized_dag) - simple_task = dag.task_dict[self.TRAINING_DEFAULT_ARGS['task_id']] - - # Check Serialized version of operator link - assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [ - {"airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink": {}} - ] - - # Check DeSerialized version of operator link - assert isinstance(list(simple_task.operator_extra_links)[0], AIPlatformConsoleLink) - - job_id = self.TRAINING_DEFAULT_ARGS['job_id'] - project_id = self.TRAINING_DEFAULT_ARGS['project_id'] - gcp_metadata = { - "job_id": job_id, - "project_id": project_id, - } - - ti.xcom_push(key='gcp_metadata', value=gcp_metadata) - - assert ( - f"https://console.cloud.google.com/ai-platform/jobs/{job_id}?project={project_id}" - == simple_task.get_extra_links(ti, AIPlatformConsoleLink.name) - ) - class TestMLEngineTrainingCancelJobOperator(unittest.TestCase): @@ -637,7 +579,7 @@ def test_success_cancel_training_job(self, mock_hook): hook_instance.cancel_job.return_value = success_response cancel_training_op = MLEngineTrainingCancelJobOperator(**self.TRAINING_DEFAULT_ARGS) - cancel_training_op.execute(None) + cancel_training_op.execute(context=MagicMock()) mock_hook.assert_called_once_with( gcp_conn_id='google_cloud_default', @@ -660,7 +602,7 @@ def test_http_error(self, mock_hook): with pytest.raises(HttpError) as ctx: cancel_training_op = MLEngineTrainingCancelJobOperator(**self.TRAINING_DEFAULT_ARGS) - cancel_training_op.execute(None) + cancel_training_op.execute(context=MagicMock()) mock_hook.assert_called_once_with( gcp_conn_id='google_cloud_default', @@ -688,7 +630,7 @@ def test_success_create_model(self, mock_hook): impersonation_chain=TEST_IMPERSONATION_CHAIN, ) - task.execute(None) + task.execute(context=MagicMock()) mock_hook.assert_called_once_with( delegate_to=TEST_DELEGATE_TO, @@ -711,7 +653,7 @@ def test_success_get_model(self, mock_hook): impersonation_chain=TEST_IMPERSONATION_CHAIN, ) - result = task.execute(None) + result = task.execute(context=MagicMock()) mock_hook.assert_called_once_with( delegate_to=TEST_DELEGATE_TO, @@ -749,7 +691,7 @@ def test_success_create_model(self, mock_hook): impersonation_chain=TEST_IMPERSONATION_CHAIN, ) - task.execute(None) + task.execute(context=MagicMock()) mock_hook.assert_called_once_with( delegate_to=TEST_DELEGATE_TO, @@ -773,7 +715,7 @@ def test_success_get_model(self, mock_hook): impersonation_chain=TEST_IMPERSONATION_CHAIN, ) - result = task.execute(None) + result = task.execute(context=MagicMock()) mock_hook.assert_called_once_with( delegate_to=TEST_DELEGATE_TO, @@ -799,7 +741,7 @@ def test_success_delete_model(self, mock_hook): delete_contents=True, ) - task.execute(None) + task.execute(context=MagicMock()) mock_hook.assert_called_once_with( delegate_to=TEST_DELEGATE_TO, @@ -852,7 +794,7 @@ def test_success(self, mock_hook): impersonation_chain=TEST_IMPERSONATION_CHAIN, ) - task.execute(None) + task.execute(context=MagicMock()) mock_hook.assert_called_once_with( delegate_to=TEST_DELEGATE_TO, @@ -899,7 +841,7 @@ def test_success(self, mock_hook): impersonation_chain=TEST_IMPERSONATION_CHAIN, ) - task.execute(None) + task.execute(context=MagicMock()) mock_hook.assert_called_once_with( delegate_to=TEST_DELEGATE_TO, @@ -945,7 +887,7 @@ def test_success(self, mock_hook): impersonation_chain=TEST_IMPERSONATION_CHAIN, ) - task.execute(None) + task.execute(context=MagicMock()) mock_hook.assert_called_once_with( delegate_to=TEST_DELEGATE_TO, @@ -981,7 +923,7 @@ def test_success(self, mock_hook): impersonation_chain=TEST_IMPERSONATION_CHAIN, ) - task.execute(None) + task.execute(context=MagicMock()) mock_hook.assert_called_once_with( delegate_to=TEST_DELEGATE_TO,