From 107c36da95ff1b8f4ed35694fe687acc9bb3d517 Mon Sep 17 00:00:00 2001 From: Ramit Kataria Date: Fri, 31 Jan 2025 17:38:01 -0800 Subject: [PATCH 01/19] Implement hook to invoke REST API for MWAA --- .../providers/amazon/aws/hooks/mwaa.py | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 providers/src/airflow/providers/amazon/aws/hooks/mwaa.py diff --git a/providers/src/airflow/providers/amazon/aws/hooks/mwaa.py b/providers/src/airflow/providers/amazon/aws/hooks/mwaa.py new file mode 100644 index 0000000000000..fd484c7a21b49 --- /dev/null +++ b/providers/src/airflow/providers/amazon/aws/hooks/mwaa.py @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains AWS MWAA hook.""" + +from __future__ import annotations +import botocore.exceptions + +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + + +class MwaaHook(AwsBaseHook): + """ + Interact with AWS Manager Workflows for Apache Airflow + + Provide thin wrapper around :external+boto3:py:class:`boto3.client("mwaa") ` + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. 
seealso:: + - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(client_type="mwaa", *args, **kwargs) + + def invoke_rest_api(self, + env_name: str, + path: str, + method: str, + body: dict | None = None, + query_params: dict | None = None + ) -> dict: + """ + Invoke the REST API on the Airflow webserver with the specified inputs. + + .. seealso:: + - :external+boto3:py:meth:`MWAA.Client.invoke_rest_api` + + :param env_name: name of the MWAA environment + :param path: Apache Airflow REST API endpoint path to be called + :param method: HTTP method used for making Airflow REST API calls + :param body: Request body for the Apache Airflow REST API call + :param query_params: Query parameters to be included in the Apache Airflow REST API call + """ + api_kwargs = { + "Name": env_name, + "Path": path, + "Method": method, + # Filter out keys with None values because Airflow REST API doesn't accept requests otherwise + "Body": {k: v for k, v in body.items() if v is not None}, + "QueryParameters": query_params if query_params else {} + } + try: + result = self.get_conn().invoke_rest_api(**api_kwargs) + result.pop("ResponseMetadata") + return result + except botocore.exceptions.ClientError as e: + to_log = e.response + to_log.pop("ResponseMetadata") + to_log.pop("Error") + self.log.error(to_log) + raise e From fb9f130f397fa599a4be4cc4b06d0182025600d9 Mon Sep 17 00:00:00 2001 From: Ramit Kataria Date: Fri, 31 Jan 2025 17:38:36 -0800 Subject: [PATCH 02/19] Implement MwaaTriggerDagRunOperator --- .../providers/amazon/aws/operators/mwaa.py | 111 ++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 providers/src/airflow/providers/amazon/aws/operators/mwaa.py diff --git a/providers/src/airflow/providers/amazon/aws/operators/mwaa.py b/providers/src/airflow/providers/amazon/aws/operators/mwaa.py new file mode 100644 index 0000000000000..1a89955b2a4b7 --- /dev/null +++ 
b/providers/src/airflow/providers/amazon/aws/operators/mwaa.py @@ -0,0 +1,111 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains AWS MWAA operators.""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING + +from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook +from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator +from airflow.providers.amazon.aws.utils.mixins import aws_template_fields + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class MwaaTriggerDagRunOperator(AwsBaseOperator[MwaaHook]): + """ + Trigger a Dag Run for a Dag in an Amazon MWAA environment. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:MwaaTriggerDagRunOperator` + + :param env_name: The MWAA environment name (templated) + :param trigger_dag_id: The ID of the DAG to be triggered (templated) + :param trigger_run_id: The Run ID. The value of this field can be set only when creating the object. This + together with trigger_dag_id are a unique key. (templated) + :param logical_date: The logical date (previously called execution date). 
This is the time or interval + covered by this DAG run, according to the DAG definition. The value of this field can be set only when + creating the object. This together with trigger_dag_id are a unique key. (templated) + :param data_interval_start: The beginning of the interval the DAG run covers + :param data_interval_end: The end of the interval the DAG run covers + :param conf: Additional configuration parameters. The value of this field can be set only when creating + the object. (templated) + :param note: Contains manually entered notes by the user about the DagRun. (templated) + """ + + aws_hook_class = MwaaHook + template_fields: Sequence[str] = aws_template_fields( + "env_name", + "trigger_dag_id", + "trigger_run_id", + "logical_date", + "data_interval_start", + "data_interval_end", + "conf", + "note", + ) + template_fields_renderers = {"conf": "json"} + ui_color = "#6ad3fa" + + def __init__( + self, + *, + env_name: str, + trigger_dag_id: str, + trigger_run_id: str | None = None, + logical_date: str | None = None, + data_interval_start: str | None = None, + data_interval_end: str | None = None, + conf: dict | None = None, + note: str | None = None, + **kwargs + ): + super().__init__(**kwargs) + self.env_name = env_name + self.trigger_dag_id = trigger_dag_id + self.trigger_run_id = trigger_run_id + self.logical_date = logical_date + self.data_interval_start = data_interval_start + self.data_interval_end = data_interval_end + self.conf = conf if conf else {} + self.note = note + + def execute(self, context: Context) -> dict: + """ + Trigger a Dag Run for the Dag in the Amazon MWAA environment. 
+ + :param context: the Context object + :return: dict with information about the Dag run + For details of the returned dict, see :py:meth:`botocore.client.MWAA.invoke_rest_api` + """ + return self.hook.invoke_rest_api( + env_name=self.env_name, + path=f"/dags/{self.trigger_dag_id}/dagRuns", + method="POST", + body={ + "dag_run_id": self.trigger_run_id, + "logical_date": self.logical_date, + "data_interval_start": self.data_interval_start, + "data_interval_end": self.data_interval_end, + "conf": self.conf, + "note": self.note + } + ) From 24c2b182e84eb3d43daac45a7998796c65ef8eb6 Mon Sep 17 00:00:00 2001 From: Ramit Kataria Date: Fri, 31 Jan 2025 17:39:08 -0800 Subject: [PATCH 03/19] Add system test for MWAA --- .../tests/system/amazon/aws/example_mwaa.py | 89 +++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 providers/tests/system/amazon/aws/example_mwaa.py diff --git a/providers/tests/system/amazon/aws/example_mwaa.py b/providers/tests/system/amazon/aws/example_mwaa.py new file mode 100644 index 0000000000000..3b620abcfc634 --- /dev/null +++ b/providers/tests/system/amazon/aws/example_mwaa.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from datetime import datetime + +from airflow.models.baseoperator import chain +from airflow.models.dag import DAG +from airflow.providers.amazon.aws.operators.mwaa import MwaaTriggerDagRunOperator + +from providers.tests.system.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder + + +DAG_ID = "example_mwaa" + +# Externally fetched variables: +EXISTING_ENVIRONMENT_NAME_KEY = "ENVIRONMENT_NAME" +EXISTING_DAG_ID_KEY = "TRIGGER_DAG_ID" + + +sys_test_context_task = ( + SystemTestContextBuilder() + # NOTE: Creating a functional MWAA environment is time-consuming and requires + # manually creating and configuring an S3 bucket for DAG storage and a VPC with + # private subnets which is out of scope for this demo. To simplify this demo and + # make it run in a reasonable time, follow these steps in the AWS Console to create + # a new MWAA environment with default configuration: + # 1. Create an S3 bucket and upload your DAGs to a 'dags' directory + # 2. Navigate to the MWAA console + # 3. Create an environment, making sure to use the S3 bucket from the previous step + # 4. Use the default VPC/network settings (or create a new one using this guide: + # https://docs.aws.amazon.com/mwaa/latest/userguide/vpc-create.html) + # 5. 
Select Public network for web server access and click through to creation + .add_variable(EXISTING_ENVIRONMENT_NAME_KEY) + .add_variable(EXISTING_DAG_ID_KEY) + .build() +) + +with DAG( + dag_id=DAG_ID, + schedule="@once", + start_date=datetime(2021, 1, 1), + tags=["example"], + catchup=False, +) as dag: + test_context = sys_test_context_task() + env_id = test_context[ENV_ID_KEY] + env_name = test_context["ENVIRONMENT_NAME"] + trigger_dag_id = test_context["TRIGGER_DAG_ID"] + + # [START howto_operator_mwaa_trigger_dag_run] + trigger_dag_run = MwaaTriggerDagRunOperator( + task_id='trigger_dag_run', + env_name=env_name, + trigger_dag_id=trigger_dag_id, + ) + # [END howto_operator_mwaa_trigger_dag_run] + + chain( + # TEST SETUP + test_context, + # TEST BODY + trigger_dag_run + ) + + from tests_common.test_utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + +from tests_common.test_utils.system_tests import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) From bd4cae85fee8fd5e47a06b68b4a36a2d3117b6e7 Mon Sep 17 00:00:00 2001 From: Ramit Kataria Date: Fri, 31 Jan 2025 17:39:33 -0800 Subject: [PATCH 04/19] Create doc for MWAA operators --- .../operators/mwaa.rst | 60 +++++++++++++++++++ docs/spelling_wordlist.txt | 1 + 2 files changed, 61 insertions(+) create mode 100644 docs/apache-airflow-providers-amazon/operators/mwaa.rst diff --git a/docs/apache-airflow-providers-amazon/operators/mwaa.rst b/docs/apache-airflow-providers-amazon/operators/mwaa.rst new file mode 100644 index 0000000000000..cbe6944ce988f --- /dev/null +++ b/docs/apache-airflow-providers-amazon/operators/mwaa.rst @@ -0,0 +1,60 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. 
See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +================================================== +Amazon Managed Workflows for Apache Airflow (MWAA) +================================================== + +`Amazon Managed Workflows for Apache Airflow (MWAA) `__ +is a managed service for Apache Airflow that lets you use your current, familiar Apache Airflow platform to orchestrate +your workflows. You gain improved scalability, availability, and security without the operational burden of managing +underlying infrastructure. + +Prerequisite Tasks +------------------ + +.. include:: ../_partials/prerequisite_tasks.rst + +Generic Parameters +------------------ + +.. include:: ../_partials/generic_parameters.rst + +Operators +--------- + +.. _howto/operator:MwaaTriggerDagRunOperator: + +Trigger a DAG run in an Amazon MWAA environment +=============================================== + +To trigger a DAG run in an Amazon MWAA environment you can use the +:class:`~airflow.providers.amazon.aws.operators.mwaa.MwaaTriggerDagRunOperator` + +In the following example, the task ``trigger_dag_run`` triggers a dag run for a DAG with with the ID ``hello_world`` in +the environment ``MyAirflowEnvironment``. + +.. 
exampleinclude:: /../../providers/tests/system/amazon/aws/example_mwaa.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_mwaa_trigger_dag_run] + :end-before: [END howto_operator_mwaa_trigger_dag_run] + +References +---------- + +* `AWS boto3 library documentation for MWAA `__ diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 2b35ed5bc2dd9..a6bde9a160226 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1099,6 +1099,7 @@ muldelete Multinamespace mutex mv +mwaa mypy Mysql mysql From 6fb9ee25b6569cdc69798d7a71d7d24439f2807c Mon Sep 17 00:00:00 2001 From: Ramit Kataria Date: Mon, 3 Feb 2025 16:41:32 -0800 Subject: [PATCH 05/19] Add MWAA operators and hooks to provider.yaml --- providers/src/airflow/providers/amazon/provider.yaml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/providers/src/airflow/providers/amazon/provider.yaml b/providers/src/airflow/providers/amazon/provider.yaml index d89532b848c1b..fa129909a7507 100644 --- a/providers/src/airflow/providers/amazon/provider.yaml +++ b/providers/src/airflow/providers/amazon/provider.yaml @@ -450,6 +450,9 @@ operators: - integration-name: Amazon Managed Service for Apache Flink python-modules: - airflow.providers.amazon.aws.operators.kinesis_analytics + - integration-name: Amazon Managed Workflows for Apache Airflow (MWAA) + python-modules: + - airflow.providers.amazon.aws.operators.mwaa - integration-name: Amazon Simple Storage Service (S3) python-modules: - airflow.providers.amazon.aws.operators.s3 @@ -658,6 +661,9 @@ hooks: - integration-name: Amazon CloudWatch Logs python-modules: - airflow.providers.amazon.aws.hooks.logs + - integration-name: Amazon Managed Workflows for Apache Airflow (MWAA) + python-modules: + - airflow.providers.amazon.aws.hooks.mwaa - integration-name: Amazon OpenSearch Serverless python-modules: - airflow.providers.amazon.aws.hooks.opensearch_serverless From a7f2106782c752df7c01b2215903c30c4bf7e1e9 Mon Sep 17 
00:00:00 2001 From: Ramit Kataria Date: Mon, 3 Feb 2025 16:41:51 -0800 Subject: [PATCH 06/19] Fix pre-commit issues --- .../providers/amazon/aws/hooks/mwaa.py | 24 +++++++++------- .../providers/amazon/aws/operators/mwaa.py | 28 +++++++++---------- .../tests/system/amazon/aws/example_mwaa.py | 4 +-- 3 files changed, 30 insertions(+), 26 deletions(-) diff --git a/providers/src/airflow/providers/amazon/aws/hooks/mwaa.py b/providers/src/airflow/providers/amazon/aws/hooks/mwaa.py index fd484c7a21b49..737d0035beb3c 100644 --- a/providers/src/airflow/providers/amazon/aws/hooks/mwaa.py +++ b/providers/src/airflow/providers/amazon/aws/hooks/mwaa.py @@ -17,6 +17,7 @@ """This module contains AWS MWAA hook.""" from __future__ import annotations + import botocore.exceptions from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook @@ -24,7 +25,7 @@ class MwaaHook(AwsBaseHook): """ - Interact with AWS Manager Workflows for Apache Airflow + Interact with AWS Manager Workflows for Apache Airflow. Provide thin wrapper around :external+boto3:py:class:`boto3.client("mwaa") ` @@ -36,15 +37,17 @@ class MwaaHook(AwsBaseHook): """ def __init__(self, *args, **kwargs) -> None: - super().__init__(client_type="mwaa", *args, **kwargs) + kwargs["client_type"] = "mwaa" + super().__init__(*args, **kwargs) - def invoke_rest_api(self, - env_name: str, - path: str, - method: str, - body: dict | None = None, - query_params: dict | None = None - ) -> dict: + def invoke_rest_api( + self, + env_name: str, + path: str, + method: str, + body: dict | None = None, + query_params: dict | None = None, + ) -> dict: """ Invoke the REST API on the Airflow webserver with the specified inputs. 
@@ -57,13 +60,14 @@ def invoke_rest_api(self, :param body: Request body for the Apache Airflow REST API call :param query_params: Query parameters to be included in the Apache Airflow REST API call """ + body = body or {} api_kwargs = { "Name": env_name, "Path": path, "Method": method, # Filter out keys with None values because Airflow REST API doesn't accept requests otherwise "Body": {k: v for k, v in body.items() if v is not None}, - "QueryParameters": query_params if query_params else {} + "QueryParameters": query_params if query_params else {}, } try: result = self.get_conn().invoke_rest_api(**api_kwargs) diff --git a/providers/src/airflow/providers/amazon/aws/operators/mwaa.py b/providers/src/airflow/providers/amazon/aws/operators/mwaa.py index 1a89955b2a4b7..c77fa61c676e9 100644 --- a/providers/src/airflow/providers/amazon/aws/operators/mwaa.py +++ b/providers/src/airflow/providers/amazon/aws/operators/mwaa.py @@ -66,17 +66,17 @@ class MwaaTriggerDagRunOperator(AwsBaseOperator[MwaaHook]): ui_color = "#6ad3fa" def __init__( - self, - *, - env_name: str, - trigger_dag_id: str, - trigger_run_id: str | None = None, - logical_date: str | None = None, - data_interval_start: str | None = None, - data_interval_end: str | None = None, - conf: dict | None = None, - note: str | None = None, - **kwargs + self, + *, + env_name: str, + trigger_dag_id: str, + trigger_run_id: str | None = None, + logical_date: str | None = None, + data_interval_start: str | None = None, + data_interval_end: str | None = None, + conf: dict | None = None, + note: str | None = None, + **kwargs, ): super().__init__(**kwargs) self.env_name = env_name @@ -96,7 +96,7 @@ def execute(self, context: Context) -> dict: :return: dict with information about the Dag run For details of the returned dict, see :py:meth:`botocore.client.MWAA.invoke_rest_api` """ - return self.hook.invoke_rest_api( + return self.hook.invoke_rest_api( env_name=self.env_name, path=f"/dags/{self.trigger_dag_id}/dagRuns", 
method="POST", @@ -106,6 +106,6 @@ def execute(self, context: Context) -> dict: "data_interval_start": self.data_interval_start, "data_interval_end": self.data_interval_end, "conf": self.conf, - "note": self.note - } + "note": self.note, + }, ) diff --git a/providers/tests/system/amazon/aws/example_mwaa.py b/providers/tests/system/amazon/aws/example_mwaa.py index 3b620abcfc634..8530fb77dc708 100644 --- a/providers/tests/system/amazon/aws/example_mwaa.py +++ b/providers/tests/system/amazon/aws/example_mwaa.py @@ -64,7 +64,7 @@ # [START howto_operator_mwaa_trigger_dag_run] trigger_dag_run = MwaaTriggerDagRunOperator( - task_id='trigger_dag_run', + task_id="trigger_dag_run", env_name=env_name, trigger_dag_id=trigger_dag_id, ) @@ -74,7 +74,7 @@ # TEST SETUP test_context, # TEST BODY - trigger_dag_run + trigger_dag_run, ) from tests_common.test_utils.watcher import watcher From 8c4b3ab9440035ff68b9bf6a5806b9a0994cd3a3 Mon Sep 17 00:00:00 2001 From: Ramit Kataria Date: Fri, 7 Feb 2025 17:28:22 -0800 Subject: [PATCH 07/19] Handle boto ClientError edge case --- .../src/airflow/providers/amazon/aws/hooks/mwaa.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/providers/src/airflow/providers/amazon/aws/hooks/mwaa.py b/providers/src/airflow/providers/amazon/aws/hooks/mwaa.py index 737d0035beb3c..85218db496462 100644 --- a/providers/src/airflow/providers/amazon/aws/hooks/mwaa.py +++ b/providers/src/airflow/providers/amazon/aws/hooks/mwaa.py @@ -18,7 +18,7 @@ from __future__ import annotations -import botocore.exceptions +from botocore.exceptions import ClientError from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook @@ -71,11 +71,11 @@ def invoke_rest_api( } try: result = self.get_conn().invoke_rest_api(**api_kwargs) - result.pop("ResponseMetadata") + result.pop("ResponseMetadata", None) return result - except botocore.exceptions.ClientError as e: + except ClientError as e: to_log = e.response - to_log.pop("ResponseMetadata") - 
to_log.pop("Error") + to_log.pop("ResponseMetadata", None) + to_log.pop("Error", None) self.log.error(to_log) raise e From 2dc043be758b63bc92cd9c22df0430242e39f271 Mon Sep 17 00:00:00 2001 From: Ramit Kataria Date: Fri, 7 Feb 2025 17:28:58 -0800 Subject: [PATCH 08/19] Fix fetching context keys --- providers/tests/system/amazon/aws/example_mwaa.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/providers/tests/system/amazon/aws/example_mwaa.py b/providers/tests/system/amazon/aws/example_mwaa.py index 8530fb77dc708..328ed574fc128 100644 --- a/providers/tests/system/amazon/aws/example_mwaa.py +++ b/providers/tests/system/amazon/aws/example_mwaa.py @@ -29,7 +29,7 @@ # Externally fetched variables: EXISTING_ENVIRONMENT_NAME_KEY = "ENVIRONMENT_NAME" -EXISTING_DAG_ID_KEY = "TRIGGER_DAG_ID" +EXISTING_DAG_ID_KEY = "DAG_ID" sys_test_context_task = ( @@ -59,8 +59,8 @@ ) as dag: test_context = sys_test_context_task() env_id = test_context[ENV_ID_KEY] - env_name = test_context["ENVIRONMENT_NAME"] - trigger_dag_id = test_context["TRIGGER_DAG_ID"] + env_name = test_context[EXISTING_ENVIRONMENT_NAME_KEY] + trigger_dag_id = test_context[EXISTING_DAG_ID_KEY] # [START howto_operator_mwaa_trigger_dag_run] trigger_dag_run = MwaaTriggerDagRunOperator( From 051bd49bdde8e833d98f0cfe7229a5040a59db71 Mon Sep 17 00:00:00 2001 From: Ramit Kataria Date: Fri, 7 Feb 2025 17:29:23 -0800 Subject: [PATCH 09/19] Implement unit tests for MWAA operator and hook --- providers/tests/amazon/aws/hooks/test_mwaa.py | 112 ++++++++++++++++++ .../tests/amazon/aws/operators/test_mwaa.py | 73 ++++++++++++ 2 files changed, 185 insertions(+) create mode 100644 providers/tests/amazon/aws/hooks/test_mwaa.py create mode 100644 providers/tests/amazon/aws/operators/test_mwaa.py diff --git a/providers/tests/amazon/aws/hooks/test_mwaa.py b/providers/tests/amazon/aws/hooks/test_mwaa.py new file mode 100644 index 0000000000000..bd6afbced280c --- /dev/null +++ 
b/providers/tests/amazon/aws/hooks/test_mwaa.py @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import pytest +from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook +from botocore.exceptions import ClientError +from moto import mock_aws + +ENV_NAME = "test_env" +PATH = "/dags/test_dag/dagRuns" +METHOD = "POST" +QUERY_PARAMS = {"limit": 30} + + +class TestMwaaHook: + def setup_method(self): + self.hook = MwaaHook() + + def test_init(self): + assert self.hook.client_type == "mwaa" + + @mock_aws + def test_get_conn(self): + assert self.hook.get_conn() is not None + + @pytest.mark.parametrize( + "body", [ + None, # test case: empty body + {"conf": {}} # test case: non-empty body + ] + ) + @mock.patch("airflow.providers.amazon.aws.hooks.mwaa.MwaaHook.get_conn") + def test_invoke_rest_api_success(self, mock_conn, body) -> None: + boto_invoke_mock = mock.MagicMock(return_value=self.example_responses["success"]) + mock_conn.return_value.invoke_rest_api = boto_invoke_mock + + retval = self.hook.invoke_rest_api(ENV_NAME, PATH, METHOD, body, QUERY_PARAMS) + kwargs_to_assert = {"Name": ENV_NAME, "Path": PATH, "Method": METHOD, "Body": body if body else 
{}, + "QueryParameters": QUERY_PARAMS} + boto_invoke_mock.assert_called_once_with(**kwargs_to_assert) + assert retval == {k: v for k, v in self.example_responses["success"].items() if + k != "ResponseMetadata"} + + @mock.patch("airflow.providers.amazon.aws.hooks.mwaa.MwaaHook.get_conn") + def test_invoke_rest_api_failure(self, mock_conn) -> None: + error = ClientError( + error_response=self.example_responses["failure"], + operation_name="invoke_rest_api" + ) + boto_invoke_mock = mock.MagicMock(side_effect=error) + mock_conn.return_value.invoke_rest_api = boto_invoke_mock + mock_log = mock.MagicMock() + self.hook.log.error = mock_log + + with pytest.raises(ClientError) as caught_error: + self.hook.invoke_rest_api(ENV_NAME, PATH, METHOD) + assert caught_error == error + + expected_log = {k: v for k, v in self.example_responses["failure"].items() + if k != "ResponseMetadata" and k != "Error"} + mock_log.assert_called_once_with(expected_log) + + @pytest.fixture(autouse=True) + def _setup_test_cases(self): + self.example_responses = { + "success": { + 'ResponseMetadata': {'RequestId': 'some ID', 'HTTPStatusCode': 200, + 'HTTPHeaders': {'header1': 'value1'}, 'RetryAttempts': 0}, + 'RestApiStatusCode': 200, + 'RestApiResponse': { + 'conf': {}, + 'dag_id': 'hello_world', + 'dag_run_id': 'manual__2025-02-08T00:33:09.457198+00:00', + 'data_interval_end': '2025-02-08T00:33:09.457198+00:00', + 'data_interval_start': '2025-02-08T00:33:09.457198+00:00', + 'execution_date': '2025-02-08T00:33:09.457198+00:00', + 'external_trigger': True, + 'logical_date': '2025-02-08T00:33:09.457198+00:00', + 'run_type': 'manual', + 'state': 'queued' + } + }, + "failure": { + 'Error': {'Message': '', 'Code': 'RestApiClientException'}, + 'ResponseMetadata': {'RequestId': 'some ID', 'HTTPStatusCode': 400, + 'HTTPHeaders': {'header1': 'value1'}, 'RetryAttempts': 0}, + 'RestApiStatusCode': 404, + 'RestApiResponse': { + 'detail': "DAG with dag_id: 'hello_world1' not found", + 'status': 404, + 'title': 
'DAG not found', + 'type': 'https://airflow.apache.org/docs/apache-airflow/2.10.3/stable-rest-api-ref.html#section/Errors/NotFound' + } + } + } diff --git a/providers/tests/amazon/aws/operators/test_mwaa.py b/providers/tests/amazon/aws/operators/test_mwaa.py new file mode 100644 index 0000000000000..ffe054b46df88 --- /dev/null +++ b/providers/tests/amazon/aws/operators/test_mwaa.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest import mock + +from airflow.providers.amazon.aws.operators.mwaa import MwaaTriggerDagRunOperator + +from providers.tests.amazon.aws.utils.test_template_fields import validate_template_fields + +OP_KWARGS = { + "task_id": "test_task", + "env_name": "test_env", + "trigger_dag_id": "test_dag_id", + "trigger_run_id": "test_run_id", + "logical_date": "2025-01-01T00:00:01Z", + "data_interval_start": "2025-01-02T00:00:01Z", + "data_interval_end": "2025-01-03T00:00:01Z", + "conf": {"key": "value"}, + "note": "test note" +} +RETURN_VALUE = {"RestApiCode": 200, "RestApiResponse": "test response"} + +class TestMwaaTriggerDagRunOperator: + def test_init(self): + op = MwaaTriggerDagRunOperator(**OP_KWARGS) + assert op.env_name == OP_KWARGS["env_name"] + assert op.trigger_dag_id == OP_KWARGS["trigger_dag_id"] + assert op.trigger_run_id is OP_KWARGS["trigger_run_id"] + assert op.logical_date is OP_KWARGS["logical_date"] + assert op.data_interval_start is OP_KWARGS["data_interval_start"] + assert op.data_interval_end is OP_KWARGS["data_interval_end"] + assert op.conf == OP_KWARGS["conf"] + assert op.note is OP_KWARGS["note"] + + @mock.patch.object(MwaaTriggerDagRunOperator, "hook") + def test_execute(self, mock_hook): + mock_hook.invoke_rest_api.return_value = RETURN_VALUE + op = MwaaTriggerDagRunOperator(**OP_KWARGS) + ret_val = op.execute({}) + + mock_hook.invoke_rest_api.assert_called_once_with( + env_name=OP_KWARGS["env_name"], + path=f"/dags/{OP_KWARGS['trigger_dag_id']}/dagRuns", + method="POST", + body={ + "dag_run_id": OP_KWARGS["trigger_run_id"], + "logical_date": OP_KWARGS["logical_date"], + "data_interval_start": OP_KWARGS["data_interval_start"], + "data_interval_end": OP_KWARGS["data_interval_end"], + "conf": OP_KWARGS["conf"], + "note": OP_KWARGS["note"], + } + ) + assert ret_val == RETURN_VALUE + + def test_template_fields(self): + operator = MwaaTriggerDagRunOperator(**OP_KWARGS) + 
validate_template_fields(operator) From 486d077e15d7f83bc6974f0d9bb6742232fe8fee Mon Sep 17 00:00:00 2001 From: Ramit Kataria Date: Fri, 7 Feb 2025 18:05:17 -0800 Subject: [PATCH 10/19] Fix linter issues --- .../airflow/providers/amazon/provider.yaml | 4 +- providers/tests/amazon/aws/hooks/test_mwaa.py | 96 +++++++++++-------- .../tests/amazon/aws/operators/test_mwaa.py | 5 +- .../tests/system/amazon/aws/example_mwaa.py | 1 - 4 files changed, 62 insertions(+), 44 deletions(-) diff --git a/providers/src/airflow/providers/amazon/provider.yaml b/providers/src/airflow/providers/amazon/provider.yaml index fa129909a7507..9003d38bae925 100644 --- a/providers/src/airflow/providers/amazon/provider.yaml +++ b/providers/src/airflow/providers/amazon/provider.yaml @@ -452,7 +452,7 @@ operators: - airflow.providers.amazon.aws.operators.kinesis_analytics - integration-name: Amazon Managed Workflows for Apache Airflow (MWAA) python-modules: - - airflow.providers.amazon.aws.operators.mwaa + - airflow.providers.amazon.aws.operators.mwaa - integration-name: Amazon Simple Storage Service (S3) python-modules: - airflow.providers.amazon.aws.operators.s3 @@ -663,7 +663,7 @@ hooks: - airflow.providers.amazon.aws.hooks.logs - integration-name: Amazon Managed Workflows for Apache Airflow (MWAA) python-modules: - - airflow.providers.amazon.aws.hooks.mwaa + - airflow.providers.amazon.aws.hooks.mwaa - integration-name: Amazon OpenSearch Serverless python-modules: - airflow.providers.amazon.aws.hooks.opensearch_serverless diff --git a/providers/tests/amazon/aws/hooks/test_mwaa.py b/providers/tests/amazon/aws/hooks/test_mwaa.py index bd6afbced280c..ad1d5483f115b 100644 --- a/providers/tests/amazon/aws/hooks/test_mwaa.py +++ b/providers/tests/amazon/aws/hooks/test_mwaa.py @@ -19,10 +19,11 @@ from unittest import mock import pytest -from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook from botocore.exceptions import ClientError from moto import mock_aws +from 
airflow.providers.amazon.aws.hooks.mwaa import MwaaHook + ENV_NAME = "test_env" PATH = "/dags/test_dag/dagRuns" METHOD = "POST" @@ -41,10 +42,11 @@ def test_get_conn(self): assert self.hook.get_conn() is not None @pytest.mark.parametrize( - "body", [ + "body", + [ None, # test case: empty body - {"conf": {}} # test case: non-empty body - ] + {"conf": {}}, # test case: non-empty body + ], ) @mock.patch("airflow.providers.amazon.aws.hooks.mwaa.MwaaHook.get_conn") def test_invoke_rest_api_success(self, mock_conn, body) -> None: @@ -52,17 +54,22 @@ def test_invoke_rest_api_success(self, mock_conn, body) -> None: mock_conn.return_value.invoke_rest_api = boto_invoke_mock retval = self.hook.invoke_rest_api(ENV_NAME, PATH, METHOD, body, QUERY_PARAMS) - kwargs_to_assert = {"Name": ENV_NAME, "Path": PATH, "Method": METHOD, "Body": body if body else {}, - "QueryParameters": QUERY_PARAMS} + kwargs_to_assert = { + "Name": ENV_NAME, + "Path": PATH, + "Method": METHOD, + "Body": body if body else {}, + "QueryParameters": QUERY_PARAMS, + } boto_invoke_mock.assert_called_once_with(**kwargs_to_assert) - assert retval == {k: v for k, v in self.example_responses["success"].items() if - k != "ResponseMetadata"} + assert retval == { + k: v for k, v in self.example_responses["success"].items() if k != "ResponseMetadata" + } @mock.patch("airflow.providers.amazon.aws.hooks.mwaa.MwaaHook.get_conn") def test_invoke_rest_api_failure(self, mock_conn) -> None: error = ClientError( - error_response=self.example_responses["failure"], - operation_name="invoke_rest_api" + error_response=self.example_responses["failure"], operation_name="invoke_rest_api" ) boto_invoke_mock = mock.MagicMock(side_effect=error) mock_conn.return_value.invoke_rest_api = boto_invoke_mock @@ -71,42 +78,53 @@ def test_invoke_rest_api_failure(self, mock_conn) -> None: with pytest.raises(ClientError) as caught_error: self.hook.invoke_rest_api(ENV_NAME, PATH, METHOD) - assert caught_error == error - expected_log = {k: v for k, 
v in self.example_responses["failure"].items() - if k != "ResponseMetadata" and k != "Error"} + expected_log = { + k: v + for k, v in self.example_responses["failure"].items() + if k != "ResponseMetadata" and k != "Error" + } mock_log.assert_called_once_with(expected_log) + assert caught_error == error @pytest.fixture(autouse=True) def _setup_test_cases(self): self.example_responses = { "success": { - 'ResponseMetadata': {'RequestId': 'some ID', 'HTTPStatusCode': 200, - 'HTTPHeaders': {'header1': 'value1'}, 'RetryAttempts': 0}, - 'RestApiStatusCode': 200, - 'RestApiResponse': { - 'conf': {}, - 'dag_id': 'hello_world', - 'dag_run_id': 'manual__2025-02-08T00:33:09.457198+00:00', - 'data_interval_end': '2025-02-08T00:33:09.457198+00:00', - 'data_interval_start': '2025-02-08T00:33:09.457198+00:00', - 'execution_date': '2025-02-08T00:33:09.457198+00:00', - 'external_trigger': True, - 'logical_date': '2025-02-08T00:33:09.457198+00:00', - 'run_type': 'manual', - 'state': 'queued' - } + "ResponseMetadata": { + "RequestId": "some ID", + "HTTPStatusCode": 200, + "HTTPHeaders": {"header1": "value1"}, + "RetryAttempts": 0, + }, + "RestApiStatusCode": 200, + "RestApiResponse": { + "conf": {}, + "dag_id": "hello_world", + "dag_run_id": "manual__2025-02-08T00:33:09.457198+00:00", + "data_interval_end": "2025-02-08T00:33:09.457198+00:00", + "data_interval_start": "2025-02-08T00:33:09.457198+00:00", + "execution_date": "2025-02-08T00:33:09.457198+00:00", + "external_trigger": True, + "logical_date": "2025-02-08T00:33:09.457198+00:00", + "run_type": "manual", + "state": "queued", + }, }, "failure": { - 'Error': {'Message': '', 'Code': 'RestApiClientException'}, - 'ResponseMetadata': {'RequestId': 'some ID', 'HTTPStatusCode': 400, - 'HTTPHeaders': {'header1': 'value1'}, 'RetryAttempts': 0}, - 'RestApiStatusCode': 404, - 'RestApiResponse': { - 'detail': "DAG with dag_id: 'hello_world1' not found", - 'status': 404, - 'title': 'DAG not found', - 'type': 
'https://airflow.apache.org/docs/apache-airflow/2.10.3/stable-rest-api-ref.html#section/Errors/NotFound' - } - } + "Error": {"Message": "", "Code": "RestApiClientException"}, + "ResponseMetadata": { + "RequestId": "some ID", + "HTTPStatusCode": 400, + "HTTPHeaders": {"header1": "value1"}, + "RetryAttempts": 0, + }, + "RestApiStatusCode": 404, + "RestApiResponse": { + "detail": "DAG with dag_id: 'hello_world1' not found", + "status": 404, + "title": "DAG not found", + "type": "https://airflow.apache.org/docs/apache-airflow/2.10.3/stable-rest-api-ref.html#section/Errors/NotFound", + }, + }, } diff --git a/providers/tests/amazon/aws/operators/test_mwaa.py b/providers/tests/amazon/aws/operators/test_mwaa.py index ffe054b46df88..5ccdcb2dc9d97 100644 --- a/providers/tests/amazon/aws/operators/test_mwaa.py +++ b/providers/tests/amazon/aws/operators/test_mwaa.py @@ -31,10 +31,11 @@ "data_interval_start": "2025-01-02T00:00:01Z", "data_interval_end": "2025-01-03T00:00:01Z", "conf": {"key": "value"}, - "note": "test note" + "note": "test note", } RETURN_VALUE = {"RestApiCode": 200, "RestApiResponse": "test response"} + class TestMwaaTriggerDagRunOperator: def test_init(self): op = MwaaTriggerDagRunOperator(**OP_KWARGS) @@ -64,7 +65,7 @@ def test_execute(self, mock_hook): "data_interval_end": OP_KWARGS["data_interval_end"], "conf": OP_KWARGS["conf"], "note": OP_KWARGS["note"], - } + }, ) assert ret_val == RETURN_VALUE diff --git a/providers/tests/system/amazon/aws/example_mwaa.py b/providers/tests/system/amazon/aws/example_mwaa.py index 328ed574fc128..5668174464dc1 100644 --- a/providers/tests/system/amazon/aws/example_mwaa.py +++ b/providers/tests/system/amazon/aws/example_mwaa.py @@ -24,7 +24,6 @@ from providers.tests.system.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder - DAG_ID = "example_mwaa" # Externally fetched variables: From 059040a9365f53a7944d526be4da6d4601f8d0a2 Mon Sep 17 00:00:00 2001 From: Ramit Kataria Date: Fri, 7 Feb 2025 18:13:53 -0800 
Subject: [PATCH 11/19] Fix exception comparison in MWAA hook failure test --- providers/tests/amazon/aws/hooks/test_mwaa.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/providers/tests/amazon/aws/hooks/test_mwaa.py b/providers/tests/amazon/aws/hooks/test_mwaa.py index ad1d5483f115b..7720dbdb4efe5 100644 --- a/providers/tests/amazon/aws/hooks/test_mwaa.py +++ b/providers/tests/amazon/aws/hooks/test_mwaa.py @@ -85,7 +85,7 @@ def test_invoke_rest_api_failure(self, mock_conn) -> None: if k != "ResponseMetadata" and k != "Error" } mock_log.assert_called_once_with(expected_log) - assert caught_error == error + assert caught_error.value == error @pytest.fixture(autouse=True) def _setup_test_cases(self): From 47137922af1411d9833af3fa905f6333b1b35a04 Mon Sep 17 00:00:00 2001 From: Ramit Kataria Date: Sat, 8 Feb 2025 19:21:20 -0800 Subject: [PATCH 12/19] Minor testing improvements in MWAA integration --- providers/tests/amazon/aws/hooks/test_mwaa.py | 6 ++--- .../tests/amazon/aws/operators/test_mwaa.py | 25 ++++++++++++------- .../tests/system/amazon/aws/example_mwaa.py | 3 +-- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/providers/tests/amazon/aws/hooks/test_mwaa.py b/providers/tests/amazon/aws/hooks/test_mwaa.py index 7720dbdb4efe5..77af887be40ff 100644 --- a/providers/tests/amazon/aws/hooks/test_mwaa.py +++ b/providers/tests/amazon/aws/hooks/test_mwaa.py @@ -48,7 +48,7 @@ def test_get_conn(self): {"conf": {}}, # test case: non-empty body ], ) - @mock.patch("airflow.providers.amazon.aws.hooks.mwaa.MwaaHook.get_conn") + @mock.patch.object(MwaaHook, "get_conn") def test_invoke_rest_api_success(self, mock_conn, body) -> None: boto_invoke_mock = mock.MagicMock(return_value=self.example_responses["success"]) mock_conn.return_value.invoke_rest_api = boto_invoke_mock @@ -66,7 +66,7 @@ def test_invoke_rest_api_success(self, mock_conn, body) -> None: k: v for k, v in self.example_responses["success"].items() if k != "ResponseMetadata" } - 
@mock.patch("airflow.providers.amazon.aws.hooks.mwaa.MwaaHook.get_conn") + @mock.patch.object(MwaaHook, "get_conn") def test_invoke_rest_api_failure(self, mock_conn) -> None: error = ClientError( error_response=self.example_responses["failure"], operation_name="invoke_rest_api" @@ -79,13 +79,13 @@ def test_invoke_rest_api_failure(self, mock_conn) -> None: with pytest.raises(ClientError) as caught_error: self.hook.invoke_rest_api(ENV_NAME, PATH, METHOD) + assert caught_error.value == error expected_log = { k: v for k, v in self.example_responses["failure"].items() if k != "ResponseMetadata" and k != "Error" } mock_log.assert_called_once_with(expected_log) - assert caught_error.value == error @pytest.fixture(autouse=True) def _setup_test_cases(self): diff --git a/providers/tests/amazon/aws/operators/test_mwaa.py b/providers/tests/amazon/aws/operators/test_mwaa.py index 5ccdcb2dc9d97..4f58233dde730 100644 --- a/providers/tests/amazon/aws/operators/test_mwaa.py +++ b/providers/tests/amazon/aws/operators/test_mwaa.py @@ -33,7 +33,14 @@ "conf": {"key": "value"}, "note": "test note", } -RETURN_VALUE = {"RestApiCode": 200, "RestApiResponse": "test response"} +HOOK_RETURN_VALUE = { + "ResponseMetadata": {}, + "RestApiStatusCode": 200, + "RestApiResponse": { + "dag_run_id": "manual__2025-02-08T00:33:09.457198+00:00", + "other_key": "value", + }, +} class TestMwaaTriggerDagRunOperator: @@ -41,18 +48,18 @@ def test_init(self): op = MwaaTriggerDagRunOperator(**OP_KWARGS) assert op.env_name == OP_KWARGS["env_name"] assert op.trigger_dag_id == OP_KWARGS["trigger_dag_id"] - assert op.trigger_run_id is OP_KWARGS["trigger_run_id"] - assert op.logical_date is OP_KWARGS["logical_date"] - assert op.data_interval_start is OP_KWARGS["data_interval_start"] - assert op.data_interval_end is OP_KWARGS["data_interval_end"] + assert op.trigger_run_id == OP_KWARGS["trigger_run_id"] + assert op.logical_date == OP_KWARGS["logical_date"] + assert op.data_interval_start == 
OP_KWARGS["data_interval_start"] + assert op.data_interval_end == OP_KWARGS["data_interval_end"] assert op.conf == OP_KWARGS["conf"] - assert op.note is OP_KWARGS["note"] + assert op.note == OP_KWARGS["note"] @mock.patch.object(MwaaTriggerDagRunOperator, "hook") def test_execute(self, mock_hook): - mock_hook.invoke_rest_api.return_value = RETURN_VALUE + mock_hook.invoke_rest_api.return_value = HOOK_RETURN_VALUE op = MwaaTriggerDagRunOperator(**OP_KWARGS) - ret_val = op.execute({}) + op_ret_val = op.execute({}) mock_hook.invoke_rest_api.assert_called_once_with( env_name=OP_KWARGS["env_name"], @@ -67,7 +74,7 @@ def test_execute(self, mock_hook): "note": OP_KWARGS["note"], }, ) - assert ret_val == RETURN_VALUE + assert op_ret_val == HOOK_RETURN_VALUE def test_template_fields(self): operator = MwaaTriggerDagRunOperator(**OP_KWARGS) diff --git a/providers/tests/system/amazon/aws/example_mwaa.py b/providers/tests/system/amazon/aws/example_mwaa.py index 5668174464dc1..206da01d52af4 100644 --- a/providers/tests/system/amazon/aws/example_mwaa.py +++ b/providers/tests/system/amazon/aws/example_mwaa.py @@ -22,7 +22,7 @@ from airflow.models.dag import DAG from airflow.providers.amazon.aws.operators.mwaa import MwaaTriggerDagRunOperator -from providers.tests.system.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder +from providers.tests.system.amazon.aws.utils import SystemTestContextBuilder DAG_ID = "example_mwaa" @@ -57,7 +57,6 @@ catchup=False, ) as dag: test_context = sys_test_context_task() - env_id = test_context[ENV_ID_KEY] env_name = test_context[EXISTING_ENVIRONMENT_NAME_KEY] trigger_dag_id = test_context[EXISTING_DAG_ID_KEY] From 0836fc6e930ceaa54f8b7566cbd133410c779938 Mon Sep 17 00:00:00 2001 From: Ramit Kataria Date: Mon, 10 Feb 2025 17:04:13 -0800 Subject: [PATCH 13/19] Minor improvement addressing PR comments --- .../operators/mwaa.rst | 4 + .../providers/amazon/aws/hooks/mwaa.py | 4 + providers/tests/amazon/aws/hooks/test_mwaa.py | 84 
+++++++++---------- 3 files changed, 50 insertions(+), 42 deletions(-) diff --git a/docs/apache-airflow-providers-amazon/operators/mwaa.rst b/docs/apache-airflow-providers-amazon/operators/mwaa.rst index cbe6944ce988f..4092a6705a597 100644 --- a/docs/apache-airflow-providers-amazon/operators/mwaa.rst +++ b/docs/apache-airflow-providers-amazon/operators/mwaa.rst @@ -45,6 +45,10 @@ Trigger a DAG run in an Amazon MWAA environment To trigger a DAG run in an Amazon MWAA environment you can use the :class:`~airflow.providers.amazon.aws.operators.mwaa.MwaaTriggerDagRunOperator` +Note: Unlike :class:`~airflow.providers.standard.operators.trigger_dagrun.TriggerDagRunOperator`, this operator is capable of +triggering a DAG in a separate Airflow environment as long as the environment with the DAG being triggered is running on +AWS MWAA. + In the following example, the task ``trigger_dag_run`` triggers a dag run for a DAG with with the ID ``hello_world`` in the environment ``MyAirflowEnvironment``. 
diff --git a/providers/src/airflow/providers/amazon/aws/hooks/mwaa.py b/providers/src/airflow/providers/amazon/aws/hooks/mwaa.py index 85218db496462..bfe7895138f80 100644 --- a/providers/src/airflow/providers/amazon/aws/hooks/mwaa.py +++ b/providers/src/airflow/providers/amazon/aws/hooks/mwaa.py @@ -71,10 +71,14 @@ def invoke_rest_api( } try: result = self.get_conn().invoke_rest_api(**api_kwargs) + # ResponseMetadata is removed because it contains data that is either very unlikely to be useful + # in XComs and logs, or redundant given the data already included in the response result.pop("ResponseMetadata", None) return result except ClientError as e: to_log = e.response + # ResponseMetadata and Error are removed because they contain data that is either very unlikely to + # be useful in XComs and logs, or redundant given the data already included in the response to_log.pop("ResponseMetadata", None) to_log.pop("Error", None) self.log.error(to_log) diff --git a/providers/tests/amazon/aws/hooks/test_mwaa.py b/providers/tests/amazon/aws/hooks/test_mwaa.py index 77af887be40ff..fe757fa6e16cf 100644 --- a/providers/tests/amazon/aws/hooks/test_mwaa.py +++ b/providers/tests/amazon/aws/hooks/test_mwaa.py @@ -34,6 +34,48 @@ class TestMwaaHook: def setup_method(self): self.hook = MwaaHook() + # these examples responses are included here instead of as a constant because the hook will mutate + # responses causing subsequent tests to fail + self.example_responses = { + "success": { + "ResponseMetadata": { + "RequestId": "some ID", + "HTTPStatusCode": 200, + "HTTPHeaders": {"header1": "value1"}, + "RetryAttempts": 0, + }, + "RestApiStatusCode": 200, + "RestApiResponse": { + "conf": {}, + "dag_id": "hello_world", + "dag_run_id": "manual__2025-02-08T00:33:09.457198+00:00", + "data_interval_end": "2025-02-08T00:33:09.457198+00:00", + "data_interval_start": "2025-02-08T00:33:09.457198+00:00", + "execution_date": "2025-02-08T00:33:09.457198+00:00", + "external_trigger": True, + 
"logical_date": "2025-02-08T00:33:09.457198+00:00", + "run_type": "manual", + "state": "queued", + }, + }, + "failure": { + "Error": {"Message": "", "Code": "RestApiClientException"}, + "ResponseMetadata": { + "RequestId": "some ID", + "HTTPStatusCode": 400, + "HTTPHeaders": {"header1": "value1"}, + "RetryAttempts": 0, + }, + "RestApiStatusCode": 404, + "RestApiResponse": { + "detail": "DAG with dag_id: 'hello_world1' not found", + "status": 404, + "title": "DAG not found", + "type": "https://airflow.apache.org/docs/apache-airflow/2.10.3/stable-rest-api-ref.html#section/Errors/NotFound", + }, + }, + } + def test_init(self): assert self.hook.client_type == "mwaa" @@ -86,45 +128,3 @@ def test_invoke_rest_api_failure(self, mock_conn) -> None: if k != "ResponseMetadata" and k != "Error" } mock_log.assert_called_once_with(expected_log) - - @pytest.fixture(autouse=True) - def _setup_test_cases(self): - self.example_responses = { - "success": { - "ResponseMetadata": { - "RequestId": "some ID", - "HTTPStatusCode": 200, - "HTTPHeaders": {"header1": "value1"}, - "RetryAttempts": 0, - }, - "RestApiStatusCode": 200, - "RestApiResponse": { - "conf": {}, - "dag_id": "hello_world", - "dag_run_id": "manual__2025-02-08T00:33:09.457198+00:00", - "data_interval_end": "2025-02-08T00:33:09.457198+00:00", - "data_interval_start": "2025-02-08T00:33:09.457198+00:00", - "execution_date": "2025-02-08T00:33:09.457198+00:00", - "external_trigger": True, - "logical_date": "2025-02-08T00:33:09.457198+00:00", - "run_type": "manual", - "state": "queued", - }, - }, - "failure": { - "Error": {"Message": "", "Code": "RestApiClientException"}, - "ResponseMetadata": { - "RequestId": "some ID", - "HTTPStatusCode": 400, - "HTTPHeaders": {"header1": "value1"}, - "RetryAttempts": 0, - }, - "RestApiStatusCode": 404, - "RestApiResponse": { - "detail": "DAG with dag_id: 'hello_world1' not found", - "status": 404, - "title": "DAG not found", - "type": 
"https://airflow.apache.org/docs/apache-airflow/2.10.3/stable-rest-api-ref.html#section/Errors/NotFound", - }, - }, - } From b3ce918ca1b6fa49b70d3fe925808fadc00913fc Mon Sep 17 00:00:00 2001 From: Ramit Kataria Date: Mon, 10 Feb 2025 17:59:54 -0800 Subject: [PATCH 14/19] Rewrite the comment about existing MWAA environment requirement for system test --- providers/tests/system/amazon/aws/example_mwaa.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/providers/tests/system/amazon/aws/example_mwaa.py b/providers/tests/system/amazon/aws/example_mwaa.py index 206da01d52af4..3cb884e9b6b5b 100644 --- a/providers/tests/system/amazon/aws/example_mwaa.py +++ b/providers/tests/system/amazon/aws/example_mwaa.py @@ -36,14 +36,13 @@ # NOTE: Creating a functional MWAA environment is time-consuming and requires # manually creating and configuring an S3 bucket for DAG storage and a VPC with # private subnets which is out of scope for this demo. To simplify this demo and - # make it run in a reasonable time, follow these steps in the AWS Console to create - # a new MWAA environment with default configuration: - # 1. Create an S3 bucket and upload your DAGs to a 'dags' directory - # 2. Navigate to the MWAA console - # 3. Create an environment, making sure to use the S3 bucket from the previous step - # 4. Use the default VPC/network settings (or create a new one using this guide: - # https://docs.aws.amazon.com/mwaa/latest/userguide/vpc-create.html) - # 5. 
Select Public network for web server access and click through to creation + # make it run in a reasonable time, an existing MWAA environment is required + # Here's a quick start guide to create an MWAA environment using AWS CloudFormation: + # https://docs.aws.amazon.com/mwaa/latest/userguide/quick-start.html + # If creating the environment using the AWS console, make sure to have a VPC with + # at least 1 private subnet to be able to select the VPC while going through the + # environment creation steps in the console wizard. + # Make sure to set the following environment variables with appropriate values .add_variable(EXISTING_ENVIRONMENT_NAME_KEY) .add_variable(EXISTING_DAG_ID_KEY) .build() From 1ddb1f3d2645fabd1924cb81c30f988770e47e8d Mon Sep 17 00:00:00 2001 From: Ramit Kataria Date: Tue, 11 Feb 2025 16:11:57 -0800 Subject: [PATCH 15/19] Remove the irrelevant line in comments --- .../src/airflow/providers/amazon/aws/operators/mwaa.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/mwaa.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/mwaa.py index c77fa61c676e9..a43ce282703b5 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/mwaa.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/mwaa.py @@ -39,11 +39,10 @@ class MwaaTriggerDagRunOperator(AwsBaseOperator[MwaaHook]): :param env_name: The MWAA environment name (templated) :param trigger_dag_id: The ID of the DAG to be triggered (templated) - :param trigger_run_id: The Run ID. The value of this field can be set only when creating the object. This - together with trigger_dag_id are a unique key. (templated) + :param trigger_run_id: The Run ID. This together with trigger_dag_id are a unique key. (templated) :param logical_date: The logical date (previously called execution date). This is the time or interval - covered by this DAG run, according to the DAG definition.
The value of this field can be set only when - creating the object. This together with trigger_dag_id are a unique key. (templated) + covered by this DAG run, according to the DAG definition. This together with trigger_dag_id are a + unique key. (templated) :param data_interval_start: The beginning of the interval the DAG run covers :param data_interval_end: The end of the interval the DAG run covers :param conf: Additional configuration parameters. The value of this field can be set only when creating From 56cc61cc449828460683b9e1ca9460bc868b5288 Mon Sep 17 00:00:00 2001 From: Ramit Kataria Date: Tue, 11 Feb 2025 16:15:49 -0800 Subject: [PATCH 16/19] Use cached conn property instead of get_conn() --- .../amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py | 2 +- .../tests/provider_tests/amazon/aws/hooks/test_mwaa.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py index bfe7895138f80..d7f01238e6ab8 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py @@ -70,7 +70,7 @@ def invoke_rest_api( "QueryParameters": query_params if query_params else {}, } try: - result = self.get_conn().invoke_rest_api(**api_kwargs) + result = self.conn.invoke_rest_api(**api_kwargs) # ResponseMetadata is removed because it contains data that is either very unlikely to be useful # in XComs and logs, or redundant given the data already included in the response result.pop("ResponseMetadata", None) diff --git a/providers/amazon/tests/provider_tests/amazon/aws/hooks/test_mwaa.py b/providers/amazon/tests/provider_tests/amazon/aws/hooks/test_mwaa.py index fe757fa6e16cf..9a3ec5ff66444 100644 --- a/providers/amazon/tests/provider_tests/amazon/aws/hooks/test_mwaa.py +++ b/providers/amazon/tests/provider_tests/amazon/aws/hooks/test_mwaa.py @@ -90,10 
+90,10 @@ def test_get_conn(self): {"conf": {}}, # test case: non-empty body ], ) - @mock.patch.object(MwaaHook, "get_conn") + @mock.patch.object(MwaaHook, "conn") def test_invoke_rest_api_success(self, mock_conn, body) -> None: boto_invoke_mock = mock.MagicMock(return_value=self.example_responses["success"]) - mock_conn.return_value.invoke_rest_api = boto_invoke_mock + mock_conn.invoke_rest_api = boto_invoke_mock retval = self.hook.invoke_rest_api(ENV_NAME, PATH, METHOD, body, QUERY_PARAMS) kwargs_to_assert = { @@ -108,13 +108,13 @@ def test_invoke_rest_api_success(self, mock_conn, body) -> None: k: v for k, v in self.example_responses["success"].items() if k != "ResponseMetadata" } - @mock.patch.object(MwaaHook, "get_conn") + @mock.patch.object(MwaaHook, "conn") def test_invoke_rest_api_failure(self, mock_conn) -> None: error = ClientError( error_response=self.example_responses["failure"], operation_name="invoke_rest_api" ) boto_invoke_mock = mock.MagicMock(side_effect=error) - mock_conn.return_value.invoke_rest_api = boto_invoke_mock + mock_conn.invoke_rest_api = boto_invoke_mock mock_log = mock.MagicMock() self.hook.log.error = mock_log From 9e9411b2ac68e2ec41690ed95f5ff20d32348df9 Mon Sep 17 00:00:00 2001 From: Ramit Kataria Date: Tue, 11 Feb 2025 16:34:24 -0800 Subject: [PATCH 17/19] Use pytest.param instead of code comments --- .../amazon/tests/provider_tests/amazon/aws/hooks/test_mwaa.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/providers/amazon/tests/provider_tests/amazon/aws/hooks/test_mwaa.py b/providers/amazon/tests/provider_tests/amazon/aws/hooks/test_mwaa.py index 9a3ec5ff66444..4ed4c477cdb66 100644 --- a/providers/amazon/tests/provider_tests/amazon/aws/hooks/test_mwaa.py +++ b/providers/amazon/tests/provider_tests/amazon/aws/hooks/test_mwaa.py @@ -86,8 +86,8 @@ def test_get_conn(self): @pytest.mark.parametrize( "body", [ - None, # test case: empty body - {"conf": {}}, # test case: non-empty body + pytest.param(None, 
id="no_body"), + pytest.param({"conf": {}}, id="non_empty_body"), ], ) @mock.patch.object(MwaaHook, "conn") From f54d404c30f9cd4a90e9f929ced4063cd98013c6 Mon Sep 17 00:00:00 2001 From: Ramit Kataria Date: Wed, 12 Feb 2025 09:18:21 -0800 Subject: [PATCH 18/19] Cleanup --- .../amazon/tests/provider_tests/amazon/aws/hooks/test_mwaa.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/providers/amazon/tests/provider_tests/amazon/aws/hooks/test_mwaa.py b/providers/amazon/tests/provider_tests/amazon/aws/hooks/test_mwaa.py index 4ed4c477cdb66..5d8dc761c3334 100644 --- a/providers/amazon/tests/provider_tests/amazon/aws/hooks/test_mwaa.py +++ b/providers/amazon/tests/provider_tests/amazon/aws/hooks/test_mwaa.py @@ -34,7 +34,7 @@ class TestMwaaHook: def setup_method(self): self.hook = MwaaHook() - # these examples responses are included here instead of as a constant because the hook will mutate + # these example responses are included here instead of as a constant because the hook will mutate # responses causing subsequent tests to fail self.example_responses = { "success": { @@ -81,7 +81,7 @@ def test_init(self): @mock_aws def test_get_conn(self): - assert self.hook.get_conn() is not None + assert self.hook.conn is not None @pytest.mark.parametrize( "body", From a8e219dbd2d8c993ead38d93cf896058057707ff Mon Sep 17 00:00:00 2001 From: Ramit Kataria Date: Wed, 12 Feb 2025 14:24:27 -0800 Subject: [PATCH 19/19] More cleanup --- providers/amazon/docs/operators/mwaa.rst | 2 +- .../amazon/src/airflow/providers/amazon/aws/operators/mwaa.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/providers/amazon/docs/operators/mwaa.rst b/providers/amazon/docs/operators/mwaa.rst index 4092a6705a597..021998b0a10ed 100644 --- a/providers/amazon/docs/operators/mwaa.rst +++ b/providers/amazon/docs/operators/mwaa.rst @@ -52,7 +52,7 @@ AWS MWAA. 
In the following example, the task ``trigger_dag_run`` triggers a dag run for a DAG with with the ID ``hello_world`` in the environment ``MyAirflowEnvironment``. -.. exampleinclude:: /../../providers/tests/system/amazon/aws/example_mwaa.py +.. exampleinclude:: /../../providers/amazon/tests/system/amazon/aws/example_mwaa.py :language: python :dedent: 4 :start-after: [START howto_operator_mwaa_trigger_dag_run] diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/mwaa.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/mwaa.py index a43ce282703b5..42f1038f2c5cb 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/mwaa.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/mwaa.py @@ -62,7 +62,6 @@ class MwaaTriggerDagRunOperator(AwsBaseOperator[MwaaHook]): "note", ) template_fields_renderers = {"conf": "json"} - ui_color = "#6ad3fa" def __init__( self,