diff --git a/providers/google/src/airflow/providers/google/cloud/operators/compute.py b/providers/google/src/airflow/providers/google/cloud/operators/compute.py index cba9dd28b8c89..9ae10b5c0917c 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/compute.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/compute.py @@ -124,6 +124,9 @@ class ComputeEngineInsertInstanceOperator(ComputeEngineBaseOperator): :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if `retry` is specified, the timeout applies to each individual attempt. :param metadata: Additional metadata that is provided to the method. + :param recreate_if_machine_type_different: When True, delete and recreate the instance if + the existing machine type differs from the requested body. Defaults to + False, in which case differences are only logged. """ operator_extra_links = (ComputeInstanceDetailsLink(),) @@ -156,6 +159,7 @@ def __init__( api_version: str = "v1", validate_body: bool = True, impersonation_chain: str | Sequence[str] | None = None, + recreate_if_machine_type_different: bool = False, **kwargs, ) -> None: self.body = body @@ -167,6 +171,7 @@ def __init__( self.retry = retry self.timeout = timeout self.metadata = metadata + self.recreate_if_machine_type_different = recreate_if_machine_type_different if validate_body: self._field_validator = GcpBodyFieldValidator( @@ -206,7 +211,68 @@ def _validate_all_body_fields(self) -> None: if self._field_validator: self._field_validator.validate(self.body) + def _extract_machine_type(self, value: str | None) -> str | None: + if not value: + return None + return value.split("/")[-1] + + def _detect_instance_drift(self, existing: Instance) -> dict[str, Any]: + """Detect machine type differences between the existing instance and the requested body.""" + diffs = {} + + # Compare machine_type. 
+ requested_machine_type = self.body.get("machine_type") + existing_machine_type = getattr(existing, "machine_type", None) + + requested_name = self._extract_machine_type(requested_machine_type) + existing_name = self._extract_machine_type(existing_machine_type) + + if requested_name and existing_name and requested_name != existing_name: + diffs["machine_type"] = { + "existing": existing_name, + "requested": requested_name, + } + + return diffs + + def _create_instance(self, hook: ComputeEngineHook, context: Context) -> dict: + """Create the instance using the current body and return the created instance as dict.""" + self._field_sanitizer.sanitize(self.body) + + self.log.info("Creating Instance with specified body: %s", self.body) + + hook.insert_instance( + body=self.body, + request_id=self.request_id, + project_id=self.project_id, + zone=self.zone, + ) + + self.log.info("The specified Instance has been created SUCCESSFULLY") + + new_instance = hook.get_instance( + resource_id=self.resource_id, + project_id=self.project_id, + zone=self.zone, + ) + + ComputeInstanceDetailsLink.persist( + context=context, + project_id=self.project_id or hook.project_id, + ) + + return Instance.to_dict(new_instance) + def execute(self, context: Context) -> dict: + """ + Ensure that a Compute Engine instance with the given name exists. + + If the instance does not exist, it is created. If it already exists, + presence is treated as success (presence-based idempotence). + + If machine type drift is detected and ``recreate_if_machine_type_different=True``, + the existing instance is deleted and recreated using the requested body. 
+ """ hook = ComputeEngineHook( gcp_conn_id=self.gcp_conn_id, api_version=self.api_version, @@ -214,46 +280,54 @@ def execute(self, context: Context) -> dict: ) self._validate_all_body_fields() self.check_body_fields() + try: - # Idempotence check (sort of) - we want to check if the new Instance - # is already created and if is, then we assume it was created previously - we do - # not check if content of the Instance is as expected. - # We assume success if the Instance is simply present. existing_instance = hook.get_instance( resource_id=self.resource_id, project_id=self.project_id, zone=self.zone, ) except exceptions.NotFound as e: - # We actually expect to get 404 / Not Found here as the should not yet exist + # We expect a 404 here if the instance does not yet exist. if e.code != 404: raise e - else: - self.log.info("The %s Instance already exists", self.resource_id) - ComputeInstanceDetailsLink.persist( - context=context, - project_id=self.project_id or hook.project_id, + + # Create instance if it does not exist. + return self._create_instance(hook, context) + + # Instance already exists. + self.log.info("The %s Instance already exists", self.resource_id) + + # Detect drift. 
+ diffs = self._detect_instance_drift(existing_instance) + if diffs: + self.log.warning( + "Existing instance '%s' differs from requested configuration: %s", + self.resource_id, + diffs, ) - return Instance.to_dict(existing_instance) - self._field_sanitizer.sanitize(self.body) - self.log.info("Creating Instance with specified body: %s", self.body) - hook.insert_instance( - body=self.body, - request_id=self.request_id, - project_id=self.project_id, - zone=self.zone, - ) - self.log.info("The specified Instance has been created SUCCESSFULLY") - new_instance = hook.get_instance( - resource_id=self.resource_id, - project_id=self.project_id, - zone=self.zone, - ) + + if self.recreate_if_machine_type_different: + self.log.info( + "Recreating instance '%s' because recreate_if_machine_type_different=True", + self.resource_id, + ) + + hook.delete_instance( + resource_id=self.resource_id, + project_id=self.project_id, + request_id=self.request_id, + zone=self.zone, + ) + + return self._create_instance(hook, context) + ComputeInstanceDetailsLink.persist( context=context, project_id=self.project_id or hook.project_id, ) - return Instance.to_dict(new_instance) + + return Instance.to_dict(existing_instance) class ComputeEngineInsertInstanceFromTemplateOperator(ComputeEngineBaseOperator): diff --git a/providers/google/tests/system/google/cloud/compute/example_compute_recreate_drift.py b/providers/google/tests/system/google/cloud/compute/example_compute_recreate_drift.py new file mode 100644 index 0000000000000..b63c76f979858 --- /dev/null +++ b/providers/google/tests/system/google/cloud/compute/example_compute_recreate_drift.py @@ -0,0 +1,139 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +System test for ComputeEngineInsertInstanceOperator +verifying that recreate_if_machine_type_different=True recreates the +instance with the requested machine type when the machine type drifts. +""" + +from __future__ import annotations + +import os +from datetime import datetime + +from airflow.models.dag import DAG +from airflow.operators.python import PythonOperator +from airflow.providers.google.cloud.hooks.compute import ComputeEngineHook +from airflow.providers.google.cloud.operators.compute import ( + ComputeEngineDeleteInstanceOperator, + ComputeEngineInsertInstanceOperator, +) + +try: + from airflow.sdk import TriggerRule +except ImportError: + from airflow.utils.trigger_rule import TriggerRule # type: ignore[no-redef] + +from system.google import DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") or DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID + +DAG_ID = "cloud_compute_insert_recreate_if_different" +LOCATION = "us-central1-a" + +INSTANCE_NAME = f"airflow-drift-test-{ENV_ID}" +MACHINE_TYPE_A = "n1-standard-1" +MACHINE_TYPE_B = "n1-standard-2" + +BASE_BODY = { + "name": INSTANCE_NAME, + "disks": [ + { + "boot": True, + "auto_delete": True, + "initialize_params": { + "disk_size_gb": "10", + "source_image": "projects/debian-cloud/global/images/family/debian-12", + }, + } + ], + "network_interfaces": [{"network": 
"global/networks/default"}], +} + + +def assert_machine_type(): + hook = ComputeEngineHook() + instance = hook.get_instance( + project_id=PROJECT_ID, + zone=LOCATION, + resource_id=INSTANCE_NAME, + ) + + machine_type = instance.machine_type.split("/")[-1] + + assert machine_type == MACHINE_TYPE_B, f"Expected machine type {MACHINE_TYPE_B}, got {machine_type}" + + +with DAG( + DAG_ID, + schedule="@once", + start_date=datetime(2021, 1, 1), + catchup=False, + tags=["example", "compute"], +) as dag: + # Step 1: Create with machine type A. + create_instance = ComputeEngineInsertInstanceOperator( + task_id="create_instance", + project_id=PROJECT_ID, + zone=LOCATION, + body={ + **BASE_BODY, + "machine_type": f"zones/{LOCATION}/machineTypes/{MACHINE_TYPE_A}", + }, + ) + + # Step 2: Re-run with a different machine type and recreate_if_machine_type_different=True. + recreate_instance = ComputeEngineInsertInstanceOperator( + task_id="recreate_instance", + project_id=PROJECT_ID, + zone=LOCATION, + body={ + **BASE_BODY, + "machine_type": f"zones/{LOCATION}/machineTypes/{MACHINE_TYPE_B}", + }, + recreate_if_machine_type_different=True, + ) + + # Step 3: Validate new machine type. + validate_machine_type = PythonOperator( + task_id="validate_machine_type", + python_callable=assert_machine_type, + ) + + # Step 4: Cleanup. + delete_instance = ComputeEngineDeleteInstanceOperator( + task_id="delete_instance", + project_id=PROJECT_ID, + zone=LOCATION, + resource_id=INSTANCE_NAME, + trigger_rule=TriggerRule.ALL_DONE, + ) + + create_instance >> recreate_instance >> validate_machine_type >> delete_instance + + # Everything below this line is required for system tests. 
+ from tests_common.test_utils.watcher import watcher + + list(dag.tasks) >> watcher() + + +from tests_common.test_utils.system_tests import get_test_run # noqa: E402 + +test_run = get_test_run(dag) diff --git a/providers/google/tests/unit/google/cloud/operators/test_compute.py b/providers/google/tests/unit/google/cloud/operators/test_compute.py index 5d0c8c65534ae..5d8bc187b07fd 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_compute.py +++ b/providers/google/tests/unit/google/cloud/operators/test_compute.py @@ -18,6 +18,7 @@ from __future__ import annotations import ast +import logging from copy import deepcopy from unittest import mock @@ -255,6 +256,86 @@ def test_insert_instance_should_not_throw_ex_when_name_is_templated(self, mock_h request_id=None, ) + @mock.patch(COMPUTE_ENGINE_HOOK_PATH) + def test_insert_instance_should_recreate_on_drift(self, mock_hook): + + get_instance_obj_mock = mock.MagicMock() + get_instance_obj_mock.__class__ = Instance + + # Set existing machine_type config. + get_instance_obj_mock.machine_type = "zones/zone/machineTypes/old-type" + + mock_hook.return_value.get_instance.side_effect = [ + get_instance_obj_mock, # First existence check. + get_instance_obj_mock, # After recreation fetch. + ] + + body = deepcopy(GCE_INSTANCE_BODY_API_CALL) + + # Set config for new machine_type. 
+ body["machine_type"] = "zones/zone/machineTypes/new-type" + + op = ComputeEngineInsertInstanceOperator( + project_id=GCP_PROJECT_ID, + body=body, + zone=GCE_ZONE, + task_id=TASK_ID, + recreate_if_machine_type_different=True, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + op.execute(context=mock.MagicMock()) + + mock_hook.return_value.delete_instance.assert_called_once_with( + resource_id=op.resource_id, + project_id=GCP_PROJECT_ID, + request_id=None, + zone=GCE_ZONE, + ) + + mock_hook.return_value.insert_instance.assert_called_once_with( + body=body, + request_id=None, + project_id=GCP_PROJECT_ID, + zone=GCE_ZONE, + ) + + @mock.patch(COMPUTE_ENGINE_HOOK_PATH) + def test_insert_instance_logs_drift(self, mock_hook, caplog): + get_instance_obj_mock = mock.MagicMock() + get_instance_obj_mock.__class__ = Instance + + # Set existing machine_type config. + get_instance_obj_mock.machine_type = "zones/zone/machineTypes/old-type" + + mock_hook.return_value.get_instance.return_value = get_instance_obj_mock + + body = deepcopy(GCE_INSTANCE_BODY_API_CALL) + + # Set config for new machine_type. + body["machine_type"] = "zones/zone/machineTypes/new-type" + + op = ComputeEngineInsertInstanceOperator( + project_id=GCP_PROJECT_ID, + resource_id=GCE_RESOURCE_ID, + body=body, + zone=GCE_ZONE, + task_id=TASK_ID, + recreate_if_machine_type_different=False, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + with caplog.at_level(logging.WARNING): + op.execute(context=mock.MagicMock()) + + assert any("differs from requested configuration" in r.message for r in caplog.records) + + # Ensure that no instances are deleted or created. + mock_hook.return_value.delete_instance.assert_not_called() + mock_hook.return_value.insert_instance.assert_not_called() + class TestGceInstanceInsertFromTemplate: @mock.patch(COMPUTE_ENGINE_HOOK_PATH)