diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py index c132e6456f1d8..102d1fe31e5c1 100644 --- a/tests/providers/amazon/aws/operators/test_athena.py +++ b/tests/providers/amazon/aws/operators/test_athena.py @@ -39,6 +39,7 @@ from airflow.utils import timezone from airflow.utils.timezone import datetime from airflow.utils.types import DagRunType +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields TEST_DAG_ID = "unit_tests" DEFAULT_DATE = datetime(2018, 1, 1) @@ -397,3 +398,6 @@ def mock_get_table_metadata(CatalogName, DatabaseName, TableName): run_facets={"externalQuery": ExternalQueryRunFacet(externalQueryId="12345", source="awsathena")}, ) assert op.get_openlineage_facets_on_complete(None) == expected_lineage + + def test_template_fields(self): + validate_template_fields(self.athena) diff --git a/tests/providers/amazon/aws/operators/test_bedrock.py b/tests/providers/amazon/aws/operators/test_bedrock.py index b49d09b52a5c0..8cbb67d6f50df 100644 --- a/tests/providers/amazon/aws/operators/test_bedrock.py +++ b/tests/providers/amazon/aws/operators/test_bedrock.py @@ -35,6 +35,7 @@ BedrockInvokeModelOperator, BedrockRaGOperator, ) +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields if TYPE_CHECKING: from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection @@ -176,6 +177,9 @@ def test_ensure_unique_job_name(self, _, side_effect, ensure_unique_name, mock_c bedrock_hook.get_waiter.assert_not_called() self.operator.defer.assert_not_called() + def test_template_fields(self): + validate_template_fields(self.operator) + class TestBedrockCreateProvisionedModelThroughputOperator: MODEL_ARN = "testProvisionedModelArn" @@ -222,6 +226,9 @@ def test_provisioned_model_wait_combinations( assert bedrock_hook.get_waiter.call_count == wait_for_completion assert self.operator.defer.call_count == deferrable + def test_template_fields(self): + validate_template_fields(self.operator) + class TestBedrockCreateKnowledgeBaseOperator: KNOWLEDGE_BASE_ID = "knowledge_base_id" @@ -288,6 +295,9 @@ def test_returns_id(self, mock_conn): assert result == self.KNOWLEDGE_BASE_ID + def test_template_fields(self): + validate_template_fields(self.operator) + class TestBedrockCreateDataSourceOperator: DATA_SOURCE_ID = "data_source_id" @@ -317,6 +327,9 @@ def test_id_returned(self, mock_conn): assert result == self.DATA_SOURCE_ID + def test_template_fields(self): + validate_template_fields(self.operator) + class TestBedrockIngestDataOperator: INGESTION_JOB_ID = "ingestion_job_id" @@ -348,6 +361,9 @@ def test_id_returned(self, mock_conn): assert result == self.INGESTION_JOB_ID + def test_template_fields(self): + validate_template_fields(self.operator) + class TestBedrockRaGOperator: VECTOR_SEARCH_CONFIG = {"filter": {"equals": {"key": "some key", "value": "some value"}}} @@ -520,3 +536,14 @@ def test_external_sources_build_rag_config(self, prompt_template): **expected_config_without_template, **expected_config_template, } + + def test_template_fields(self): + op = BedrockRaGOperator( + task_id="test_rag", + input="some text prompt", + source_type="EXTERNAL_SOURCES", + model_arn=self.MODEL_ARN, + knowledge_base_id=self.KNOWLEDGE_BASE_ID, + vector_search_config=self.VECTOR_SEARCH_CONFIG, + ) + validate_template_fields(op) diff --git a/tests/providers/amazon/aws/operators/test_cloud_formation.py b/tests/providers/amazon/aws/operators/test_cloud_formation.py index 5de02c3622cfb..4d8fb4d12bd3c 100644 --- a/tests/providers/amazon/aws/operators/test_cloud_formation.py +++ b/tests/providers/amazon/aws/operators/test_cloud_formation.py @@ -28,6 +28,7 @@ CloudFormationDeleteStackOperator, ) from airflow.utils import timezone +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields DEFAULT_DATE = timezone.datetime(2019, 1, 1) DEFAULT_ARGS = {"owner": "airflow", "start_date": DEFAULT_DATE} @@ -87,6 +88,20 @@ def test_create_stack(self, mocked_hook_client): StackName=stack_name, TemplateBody=template_body, TimeoutInMinutes=timeout ) + def test_template_fields(self): + op = CloudFormationCreateStackOperator( + task_id="cf_create_stack_init", + stack_name="fake-stack", + cloudformation_parameters={}, + # Generic hooks parameters + aws_conn_id="fake-conn-id", + region_name="eu-west-1", + verify=True, + botocore_config={"read_timeout": 42}, + ) + + validate_template_fields(op) + class TestCloudFormationDeleteStackOperator: def test_init(self): @@ -125,3 +140,16 @@ def test_delete_stack(self, mocked_hook_client): operator.execute(MagicMock()) mocked_hook_client.delete_stack.assert_any_call(StackName=stack_name) + + def test_template_fields(self): + op = CloudFormationDeleteStackOperator( + task_id="cf_delete_stack_init", + stack_name="fake-stack", + # Generic hooks parameters + aws_conn_id="fake-conn-id", + region_name="us-east-1", + verify=False, + botocore_config={"read_timeout": 42}, + ) + + validate_template_fields(op) diff --git a/tests/providers/amazon/aws/operators/test_comprehend.py b/tests/providers/amazon/aws/operators/test_comprehend.py index 60f0fca219111..a86b779b1d502 100644 --- a/tests/providers/amazon/aws/operators/test_comprehend.py +++ b/tests/providers/amazon/aws/operators/test_comprehend.py @@ -29,6 +29,7 @@ ComprehendStartPiiEntitiesDetectionJobOperator, ) from airflow.utils.types import NOTSET +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields if TYPE_CHECKING: from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection @@ -163,6 +164,9 @@ def test_start_pii_entities_detection_job_wait_combinations( assert comprehend_hook.get_waiter.call_count == wait_for_completion assert self.operator.defer.call_count == deferrable + def test_template_fields(self): + validate_template_fields(self.operator) + class TestComprehendCreateDocumentClassifierOperator: CLASSIFIER_ARN = ( @@ -259,3 +263,6 @@ def test_create_document_classifier_wait_combinations( assert response == self.CLASSIFIER_ARN assert comprehend_hook.get_waiter.call_count == wait_for_completion assert self.operator.defer.call_count == deferrable + + def test_template_fields(self): + validate_template_fields(self.operator) diff --git a/tests/providers/amazon/aws/operators/test_datasync.py b/tests/providers/amazon/aws/operators/test_datasync.py index e1a44ce99e28c..18b0e86103c0b 100644 --- a/tests/providers/amazon/aws/operators/test_datasync.py +++ b/tests/providers/amazon/aws/operators/test_datasync.py @@ -29,6 +29,7 @@ from airflow.utils import timezone from airflow.utils.timezone import datetime from airflow.utils.types import DagRunType +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields TEST_DAG_ID = "unit_tests" DEFAULT_DATE = datetime(2018, 1, 1) @@ -363,6 +364,10 @@ def test_return_value(self, mock_get_conn, session, clean_dags_and_dagruns): # ### Check mocks: mock_get_conn.assert_called() + def test_template_fields(self, mock_get_conn): + self.set_up_operator() + validate_template_fields(self.datasync) + @mock_aws @mock.patch.object(DataSyncHook, "get_conn") diff --git a/tests/providers/amazon/aws/operators/test_dms.py b/tests/providers/amazon/aws/operators/test_dms.py index fba14a6370dd7..2528edaef9e0a 100644 --- a/tests/providers/amazon/aws/operators/test_dms.py +++ b/tests/providers/amazon/aws/operators/test_dms.py @@ -34,6 +34,7 @@ ) from airflow.utils import timezone from airflow.utils.types import DagRunType +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields TASK_ARN = "test_arn" @@ -121,6 +122,18 @@ def test_create_task_with_migration_type( assert dms_hook.get_task_status(TASK_ARN) == "ready" + def test_template_fields(self): + op = DmsCreateTaskOperator( + task_id="create_task", + **self.TASK_DATA, + aws_conn_id="fake-conn-id", + region_name="ca-west-1", + verify=True, + botocore_config={"read_timeout": 42}, + ) + + validate_template_fields(op) + class TestDmsDeleteTaskOperator: TASK_DATA = { @@ -174,6 +187,19 @@ def test_delete_task( assert dms_hook.get_task_status(TASK_ARN) == "deleting" + def test_template_fields(self): + op = DmsDeleteTaskOperator( + task_id="delete_task", + replication_task_arn=TASK_ARN, + # Generic hooks parameters + aws_conn_id="fake-conn-id", + region_name="us-east-1", + verify=False, + botocore_config={"read_timeout": 42}, + ) + + validate_template_fields(op) + class TestDmsDescribeTasksOperator: FILTER = {"Name": "replication-task-arn", "Values": [TASK_ARN]} @@ -267,6 +293,18 @@ def test_describe_tasks_return_value(self, mock_conn, mock_describe_replication_ assert marker is None assert response == self.MOCK_RESPONSE + def test_template_fields(self): + op = DmsDescribeTasksOperator( + task_id="describe_tasks", + describe_tasks_kwargs={"Filters": [self.FILTER]}, + # Generic hooks parameters + aws_conn_id="fake-conn-id", + region_name="eu-west-2", + verify="/foo/bar/spam.egg", + botocore_config={"read_timeout": 42}, + ) + validate_template_fields(op) + class TestDmsStartTaskOperator: TASK_DATA = { @@ -324,6 +362,19 @@ def test_start_task( assert dms_hook.get_task_status(TASK_ARN) == "starting" + def test_template_fields(self): + op = DmsStartTaskOperator( + task_id="start_task", + replication_task_arn=TASK_ARN, + # Generic hooks parameters + aws_conn_id="fake-conn-id", + region_name="us-west-1", + verify=False, + botocore_config={"read_timeout": 42}, + ) + + validate_template_fields(op) + class TestDmsStopTaskOperator: TASK_DATA = { @@ -376,3 +427,16 @@ def test_stop_task( mock_stop_replication_task.assert_called_once_with(replication_task_arn=TASK_ARN) assert dms_hook.get_task_status(TASK_ARN) == "stopping" + + def test_template_fields(self): + op = DmsStopTaskOperator( + task_id="stop_task", + replication_task_arn=TASK_ARN, + # Generic hooks parameters + aws_conn_id="fake-conn-id", + region_name="eu-west-1", + verify=True, + botocore_config={"read_timeout": 42}, + ) + + validate_template_fields(op) diff --git a/tests/providers/amazon/aws/operators/test_ec2.py b/tests/providers/amazon/aws/operators/test_ec2.py index 8f8a755a84357..a5ea81ff6ae87 100644 --- a/tests/providers/amazon/aws/operators/test_ec2.py +++ b/tests/providers/amazon/aws/operators/test_ec2.py @@ -30,6 +30,7 @@ EC2StopInstanceOperator, EC2TerminateInstanceOperator, ) +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields class BaseEc2TestClass: @@ -87,6 +88,13 @@ def test_create_multiple_instances(self): for id in instance_ids: assert ec2_hook.get_instance_state(instance_id=id) == "running" + def test_template_fields(self): + ec2_operator = EC2CreateInstanceOperator( + task_id="test_create_instance", + image_id="test_image_id", + ) + validate_template_fields(ec2_operator) + class TestEC2TerminateInstanceOperator(BaseEc2TestClass): def test_init(self): @@ -140,6 +148,13 @@ def test_terminate_multiple_instances(self): for id in instance_ids: assert ec2_hook.get_instance_state(instance_id=id) == "terminated" + def test_template_fields(self): + ec2_operator = EC2TerminateInstanceOperator( + task_id="test_terminate_instance", + instance_ids="test_image_id", + ) + validate_template_fields(ec2_operator) + class TestEC2StartInstanceOperator(BaseEc2TestClass): def test_init(self): @@ -175,6 +190,17 @@ def test_start_instance(self): # assert instance state is running assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "running" + def test_template_fields(self): + ec2_operator = EC2StartInstanceOperator( + task_id="task_test", + instance_id="i-123abc", + aws_conn_id="aws_conn_test", + region_name="region-test", + check_interval=3, + ) + + validate_template_fields(ec2_operator) + class TestEC2StopInstanceOperator(BaseEc2TestClass): def test_init(self): @@ -210,6 +236,17 @@ def test_stop_instance(self): # assert instance state is running assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "stopped" + def test_template_fields(self): + ec2_operator = EC2StopInstanceOperator( + task_id="task_test", + instance_id="i-123abc", + aws_conn_id="aws_conn_test", + region_name="region-test", + check_interval=3, + ) + + validate_template_fields(ec2_operator) + class TestEC2HibernateInstanceOperator(BaseEc2TestClass): def test_init(self): @@ -322,6 +359,13 @@ def test_cannot_hibernate_some_instances(self): for id in instance_ids: assert ec2_hook.get_instance_state(instance_id=id) == "running" + def test_template_fields(self): + ec2_operator = EC2HibernateInstanceOperator( + task_id="task_test", + instance_ids="i-123abc", + ) + validate_template_fields(ec2_operator) + class TestEC2RebootInstanceOperator(BaseEc2TestClass): def test_init(self): @@ -372,3 +416,10 @@ def test_reboot_multiple_instances(self): terminate_instance.execute(None) for id in instance_ids: assert ec2_hook.get_instance_state(instance_id=id) == "running" + + def test_template_fields(self): + ec2_operator = EC2RebootInstanceOperator( + task_id="task_test", + instance_ids="i-123abc", + ) + validate_template_fields(ec2_operator) diff --git a/tests/providers/amazon/aws/operators/test_ecs.py b/tests/providers/amazon/aws/operators/test_ecs.py index a6915214a0764..be06a8802e449 100644 --- a/tests/providers/amazon/aws/operators/test_ecs.py +++ b/tests/providers/amazon/aws/operators/test_ecs.py @@ -39,6 +39,7 @@ from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher from airflow.utils.task_instance_session import set_current_task_instance_session from airflow.utils.types import NOTSET +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields CLUSTER_NAME = "test_cluster" CONTAINER_NAME = "e1ed7aac-d9b2-4315-8726-d2432bf11868" @@ -793,6 +794,17 @@ def test_execute_without_waiter(self, patch_hook_waiters): patch_hook_waiters.assert_not_called() assert result is not None + def test_template_fields(self): + op = EcsCreateClusterOperator( + task_id="task", + cluster_name=CLUSTER_NAME, + deferrable=True, + waiter_delay=12, + waiter_max_attempts=34, + ) + + validate_template_fields(op) + class TestEcsDeleteClusterOperator(EcsBaseTestCase): @pytest.mark.parametrize("waiter_delay, waiter_max_attempts", WAITERS_TEST_CASES) @@ -857,6 +869,17 @@ def test_execute_without_waiter(self, patch_hook_waiters): patch_hook_waiters.assert_not_called() assert result is not None + def test_template_fields(self): + op = EcsDeleteClusterOperator( + task_id="task", + cluster_name=CLUSTER_NAME, + deferrable=True, + waiter_delay=12, + waiter_max_attempts=34, + ) + + validate_template_fields(op) + class TestEcsDeregisterTaskDefinitionOperator(EcsBaseTestCase): warn_message = "'wait_for_completion' and waiter related params have no effect" @@ -913,6 +936,11 @@ def test_partial_deprecation_waiters_params( assert not hasattr(ti.task, "waiter_delay") assert not hasattr(ti.task, "waiter_max_attempts") + def test_template_fields(self): + op = EcsDeregisterTaskDefinitionOperator(task_id="task", task_definition=TASK_DEFINITION_NAME) + + validate_template_fields(op) + class TestEcsRegisterTaskDefinitionOperator(EcsBaseTestCase): warn_message = "'wait_for_completion' and waiter related params have no effect" @@ -990,3 +1018,8 @@ def test_partial_deprecation_waiters_params( assert not hasattr(ti.task, "wait_for_completion") assert not hasattr(ti.task, "waiter_delay") assert not hasattr(ti.task, "waiter_max_attempts") + + def test_template_fields(self): + op = EcsRegisterTaskDefinitionOperator(task_id="task", **TASK_DEFINITION_CONFIG) + + validate_template_fields(op) diff --git a/tests/providers/amazon/aws/operators/test_eks.py b/tests/providers/amazon/aws/operators/test_eks.py index 9571ca0962005..399c8e40823ae 100644 --- a/tests/providers/amazon/aws/operators/test_eks.py +++ b/tests/providers/amazon/aws/operators/test_eks.py @@ -51,6 +51,7 @@ TASK_ID, ) from tests.providers.amazon.aws.utils.eks_test_utils import convert_keys +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields from tests.providers.amazon.aws.utils.test_waiter import assert_expected_waiter_type CLUSTER_NAME = "cluster1" @@ -365,6 +366,15 @@ def test_eks_create_cluster_with_deferrable(self, mock_create_cluster, caplog): eks_create_cluster_operator.execute({}) assert "Waiting for EKS Cluster to provision. This will take some time." in caplog.messages + def test_template_fields(self): + op = EksCreateClusterOperator( + task_id=TASK_ID, + **self.create_cluster_params, + compute="fargate", + ) + + validate_template_fields(op) + class TestEksCreateFargateProfileOperator: def setup_method(self) -> None: @@ -445,6 +455,11 @@ def test_create_fargate_profile_deferrable(self, _): exc.value.trigger, EksCreateFargateProfileTrigger ), "Trigger is not a EksCreateFargateProfileTrigger" + def test_template_fields(self): + op = EksCreateFargateProfileOperator(task_id=TASK_ID, **self.create_fargate_profile_params) + + validate_template_fields(op) + class TestEksCreateNodegroupOperator: def setup_method(self) -> None: @@ -536,6 +551,12 @@ def test_create_nodegroup_deferrable_versus_wait_for_completion(self): ) assert operator.wait_for_completion is True + def test_template_fields(self): + op_kwargs = {**self.create_nodegroup_params} + op = EksCreateNodegroupOperator(task_id=TASK_ID, **op_kwargs) + + validate_template_fields(op) + class TestEksDeleteClusterOperator: def setup_method(self) -> None: @@ -575,6 +596,9 @@ def test_eks_delete_cluster_operator_with_deferrable(self): with pytest.raises(TaskDeferred): self.delete_cluster_operator.execute({}) + def test_template_fields(self): + validate_template_fields(self.delete_cluster_operator) + class TestEksDeleteNodegroupOperator: def setup_method(self) -> None: @@ -608,6 +632,9 @@ def test_existing_nodegroup_with_wait(self, mock_delete_nodegroup, mock_waiter): mock_waiter.assert_called_with(mock.ANY, clusterName=CLUSTER_NAME, nodegroupName=NODEGROUP_NAME) assert_expected_waiter_type(mock_waiter, "NodegroupDeleted") + def test_template_fields(self): + validate_template_fields(self.delete_nodegroup_operator) + class TestEksDeleteFargateProfileOperator: def setup_method(self) -> None: @@ -656,6 +683,9 @@ def test_delete_fargate_profile_deferrable(self, _): exc.value.trigger, EksDeleteFargateProfileTrigger ), "Trigger is not a EksDeleteFargateProfileTrigger" + def test_template_fields(self): + validate_template_fields(self.delete_fargate_profile_operator) + class TestEksPodOperator: @mock.patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.execute") @@ -767,3 +797,17 @@ def test_on_finish_action_handler( ) for expected_attr in expected_attributes: assert op.__getattribute__(expected_attr) == expected_attributes[expected_attr] + + def test_template_fields(self): + op = EksPodOperator( + task_id="run_pod", + pod_name="run_pod", + cluster_name=CLUSTER_NAME, + image="amazon/aws-cli:latest", + cmds=["sh", "-c", "ls"], + labels={"demo": "hello_world"}, + get_logs=True, + on_finish_action="delete_pod", + ) + + validate_template_fields(op) diff --git a/tests/providers/amazon/aws/operators/test_emr_add_steps.py b/tests/providers/amazon/aws/operators/test_emr_add_steps.py index 9ee99864e00e3..d5a999349aa53 100644 --- a/tests/providers/amazon/aws/operators/test_emr_add_steps.py +++ b/tests/providers/amazon/aws/operators/test_emr_add_steps.py @@ -31,6 +31,7 @@ from airflow.providers.amazon.aws.triggers.emr import EmrAddStepsTrigger from airflow.utils import timezone from airflow.utils.types import DagRunType +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields from tests.test_utils import AIRFLOW_MAIN_FOLDER DEFAULT_DATE = timezone.datetime(2017, 1, 1) @@ -274,3 +275,12 @@ def test_emr_add_steps_deferrable(self, mock_add_job_flow_steps, mock_get_log_ur operator.execute(MagicMock()) assert isinstance(exc.value.trigger, EmrAddStepsTrigger), "Trigger is not a EmrAddStepsTrigger" + + def test_template_fields(self): + op = EmrAddStepsOperator( + task_id="test_task", + job_flow_id="j-8989898989", + aws_conn_id="aws_default", + steps=self._config, + ) + validate_template_fields(op) diff --git a/tests/providers/amazon/aws/operators/test_emr_containers.py b/tests/providers/amazon/aws/operators/test_emr_containers.py index feeec1278e155..52306864f3597 100644 --- a/tests/providers/amazon/aws/operators/test_emr_containers.py +++ b/tests/providers/amazon/aws/operators/test_emr_containers.py @@ -25,6 +25,7 @@ from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator, EmrEksCreateClusterOperator from airflow.providers.amazon.aws.triggers.emr import EmrContainerTrigger +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields SUBMIT_JOB_SUCCESS_RETURN = { "ResponseMetadata": {"HTTPStatusCode": 200}, @@ -194,3 +195,6 @@ def test_emr_on_eks_execute_with_failure(self, mock_create_emr_on_eks_cluster): with pytest.raises(AirflowException) as ctx: self.emr_container.execute(None) assert expected_exception_msg in str(ctx.value) + + def test_template_fields(self): + validate_template_fields(self.emr_container) diff --git a/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py b/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py index 204d292c67b46..860df8c7219ac 100644 --- a/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py +++ b/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py @@ -32,6 +32,7 @@ from airflow.providers.amazon.aws.triggers.emr import EmrCreateJobFlowTrigger from airflow.utils import timezone from airflow.utils.types import DagRunType +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields from tests.providers.amazon.aws.utils.test_waiter import assert_expected_waiter_type from tests.test_utils import AIRFLOW_MAIN_FOLDER @@ -203,3 +204,6 @@ def test_create_job_flow_deferrable(self, mocked_hook_client): assert isinstance( exc.value.trigger, EmrCreateJobFlowTrigger ), "Trigger is not a EmrCreateJobFlowTrigger" + + def test_template_fields(self): + validate_template_fields(self.operator) diff --git a/tests/providers/amazon/aws/operators/test_emr_modify_cluster.py b/tests/providers/amazon/aws/operators/test_emr_modify_cluster.py index 6dada442ff79f..6f257288760c3 100644 --- a/tests/providers/amazon/aws/operators/test_emr_modify_cluster.py +++ b/tests/providers/amazon/aws/operators/test_emr_modify_cluster.py @@ -25,6 +25,7 @@ from airflow.models.dag import DAG from airflow.providers.amazon.aws.operators.emr import EmrModifyClusterOperator from airflow.utils import timezone +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields DEFAULT_DATE = timezone.datetime(2017, 1, 1) MODIFY_CLUSTER_SUCCESS_RETURN = {"ResponseMetadata": {"HTTPStatusCode": 200}, "StepConcurrencyLevel": 1} @@ -65,3 +66,6 @@ def test_execute_returns_error(self, mocked_hook_client): with pytest.raises(AirflowException, match="Modify cluster failed"): self.operator.execute(self.mock_context) + + def test_template_fields(self): + validate_template_fields(self.operator) diff --git a/tests/providers/amazon/aws/operators/test_emr_notebook_execution.py b/tests/providers/amazon/aws/operators/test_emr_notebook_execution.py index ef6cb7ebc70ec..6fcd4eeb74629 100644 --- a/tests/providers/amazon/aws/operators/test_emr_notebook_execution.py +++ b/tests/providers/amazon/aws/operators/test_emr_notebook_execution.py @@ -28,6 +28,7 @@ EmrStartNotebookExecutionOperator, EmrStopNotebookExecutionOperator, ) +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields from tests.providers.amazon.aws.utils.test_waiter import assert_expected_waiter_type PARAMS = { @@ -303,3 +304,20 @@ def test_stop_notebook_execution_waiter_config(self, mock_conn, mock_waiter, _): WaiterConfig={"Delay": delay, "MaxAttempts": waiter_max_attempts}, ) assert_expected_waiter_type(mock_waiter, "notebook_stopped") + + def test_template_fields(self): + op = EmrStartNotebookExecutionOperator( + task_id="test-id", + editor_id=PARAMS["EditorId"], + relative_path=PARAMS["RelativePath"], + cluster_id=PARAMS["ExecutionEngine"]["Id"], + service_role=PARAMS["ServiceRole"], + notebook_execution_name=PARAMS["NotebookExecutionName"], + notebook_params=PARAMS["NotebookParams"], + notebook_instance_security_group_id=PARAMS["NotebookInstanceSecurityGroupId"], + master_instance_security_group_id=PARAMS["ExecutionEngine"]["MasterInstanceSecurityGroupId"], + tags=PARAMS["Tags"], + wait_for_completion=True, + ) + + validate_template_fields(op) diff --git a/tests/providers/amazon/aws/operators/test_emr_serverless.py b/tests/providers/amazon/aws/operators/test_emr_serverless.py index 12c5cc938018e..e7a43cf079f0b 100644 --- a/tests/providers/amazon/aws/operators/test_emr_serverless.py +++ b/tests/providers/amazon/aws/operators/test_emr_serverless.py @@ -32,6 +32,7 @@ EmrServerlessStopApplicationOperator, ) from airflow.utils.types import NOTSET +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields if TYPE_CHECKING: from unittest.mock import MagicMock @@ -393,6 +394,25 @@ def test_create_application_deferrable(self, mock_conn): with pytest.raises(TaskDeferred): operator.execute(None) + def test_template_fields(self): + operator = EmrServerlessCreateApplicationOperator( + task_id=task_id, + release_label=release_label, + job_type=job_type, + client_request_token=client_request_token, + config=config, + waiter_max_attempts=3, + waiter_delay=0, + ) + + template_fields = list(operator.template_fields) + list(operator.template_fields_renderers.keys()) + + class_fields = operator.__dict__ + + missing_fields = [field for field in template_fields if field not in class_fields] + + assert not missing_fields, f"Templated fields are not available {missing_fields}" + class TestEmrServerlessStartJobOperator: def setup_method(self): @@ -1163,6 +1183,24 @@ def test_links_spark_without_applicationui_enabled( job_run_id=job_run_id, ) + def test_template_fields(self): + operator = EmrServerlessStartJobOperator( + task_id=task_id, + client_request_token=client_request_token, + application_id=application_id, + execution_role_arn=execution_role_arn, + job_driver=job_driver, + configuration_overrides=configuration_overrides, + ) + + template_fields = list(operator.template_fields) + list(operator.template_fields_renderers.keys()) + + class_fields = operator.__dict__ + + missing_fields = [field for field in template_fields if field not in class_fields] + + assert not missing_fields, f"Templated fields are not available {missing_fields}" + class TestEmrServerlessDeleteOperator: @mock.patch.object(EmrServerlessHook, "get_waiter") @@ -1277,6 +1315,19 @@ def test_delete_application_deferrable(self, mock_conn): with pytest.raises(TaskDeferred): operator.execute(None) + def test_template_fields(self): + operator = EmrServerlessDeleteApplicationOperator( + task_id=task_id, application_id=application_id_delete_operator + ) + + template_fields = list(operator.template_fields) + list(operator.template_fields_renderers.keys()) + + class_fields = operator.__dict__ + + missing_fields = [field for field in template_fields if field not in class_fields] + + assert not missing_fields, f"Templated fields are not available {missing_fields}" + class TestEmrServerlessStopOperator: @mock.patch.object(EmrServerlessHook, "get_waiter") @@ -1344,3 +1395,10 @@ def test_stop_application_deferrable_without_force_stop( operator.execute({}) assert "no running jobs found with application ID test" in caplog.messages + + def test_template_fields(self): + operator = EmrServerlessStopApplicationOperator( + task_id=task_id, application_id="test", deferrable=True, force_stop=True + ) + + validate_template_fields(operator) diff --git a/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py b/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py index 2c27c146d2d34..06ab35e4510ba 100644 --- a/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py +++ b/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py @@ -24,6 +24,7 @@ from airflow.exceptions import TaskDeferred from airflow.providers.amazon.aws.operators.emr import EmrTerminateJobFlowOperator from airflow.providers.amazon.aws.triggers.emr import EmrTerminateJobFlowTrigger +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields TERMINATE_SUCCESS_RETURN = {"ResponseMetadata": {"HTTPStatusCode": 200}} @@ -57,3 +58,13 @@ def test_create_job_flow_deferrable(self, mocked_hook_client): assert isinstance( exc.value.trigger, EmrTerminateJobFlowTrigger ), "Trigger is not a EmrTerminateJobFlowTrigger" + + def test_template_fields(self): + operator = EmrTerminateJobFlowOperator( + task_id="test_task", + job_flow_id="j-8989898989", + aws_conn_id="aws_default", + deferrable=True, + ) + + validate_template_fields(operator) diff --git a/tests/providers/amazon/aws/utils/test_template_fields.py b/tests/providers/amazon/aws/utils/test_template_fields.py new file mode 100644 index 0000000000000..689977de9bcc5 --- /dev/null +++ b/tests/providers/amazon/aws/utils/test_template_fields.py @@ -0,0 +1,28 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + + +def validate_template_fields(operator): + template_fields = list(operator.template_fields) + list(operator.template_fields_renderers.keys()) + + class_fields = operator.__dict__ + + missing_fields = [field for field in template_fields if field not in class_fields] + + assert not missing_fields, f"Templated fields are not available {missing_fields}"