From 3e7b402d97bdfe811c3b5023671252eea3a0dc3b Mon Sep 17 00:00:00 2001 From: Pavan Kumar Date: Thu, 12 Sep 2024 18:14:30 +0100 Subject: [PATCH] Added template fields tests to remaining aws operators --- .../amazon/aws/operators/test_eventbridge.py | 28 +++++ .../amazon/aws/operators/test_glacier.py | 9 ++ .../amazon/aws/operators/test_glue.py | 24 ++++ .../amazon/aws/operators/test_glue_crawler.py | 4 + .../aws/operators/test_glue_databrew.py | 5 + .../aws/operators/test_kinesis_analytics.py | 16 +++ .../aws/operators/test_lambda_function.py | 19 +++ .../amazon/aws/operators/test_neptune.py | 21 ++++ .../amazon/aws/operators/test_quicksight.py | 5 + .../amazon/aws/operators/test_rds.py | 111 ++++++++++++++++++ .../aws/operators/test_redshift_cluster.py | 53 +++++++++ .../aws/operators/test_redshift_data.py | 12 ++ .../providers/amazon/aws/operators/test_s3.py | 61 ++++++++++ .../aws/operators/test_sagemaker_base.py | 9 ++ .../aws/operators/test_sagemaker_endpoint.py | 4 + .../test_sagemaker_endpoint_config.py | 4 + .../aws/operators/test_sagemaker_model.py | 18 +++ .../aws/operators/test_sagemaker_notebook.py | 28 +++++ .../aws/operators/test_sagemaker_pipeline.py | 16 +++ .../operators/test_sagemaker_processing.py | 8 ++ .../aws/operators/test_sagemaker_training.py | 4 + .../aws/operators/test_sagemaker_transform.py | 4 + .../aws/operators/test_sagemaker_tuning.py | 4 + .../amazon/aws/operators/test_sns.py | 5 + .../amazon/aws/operators/test_sqs.py | 7 ++ .../aws/operators/test_step_function.py | 20 ++++ 26 files changed, 499 insertions(+) diff --git a/tests/providers/amazon/aws/operators/test_eventbridge.py b/tests/providers/amazon/aws/operators/test_eventbridge.py index 4dcd068a8155c..3c682f1477e03 100644 --- a/tests/providers/amazon/aws/operators/test_eventbridge.py +++ b/tests/providers/amazon/aws/operators/test_eventbridge.py @@ -29,6 +29,7 @@ EventBridgePutEventsOperator, EventBridgePutRuleOperator, ) +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields if TYPE_CHECKING: from unittest.mock import MagicMock @@ -96,6 +97,13 @@ def test_failed_to_send(self, mock_conn: MagicMock): with pytest.raises(AirflowException): operator.execute(context={}) + def test_template_fields(self): + operator = EventBridgePutEventsOperator( + task_id="failed_put_events_job", + entries=ENTRIES, + ) + validate_template_fields(operator) + class TestEventBridgePutRuleOperator: def test_init(self): @@ -150,6 +158,12 @@ def test_put_rule_with_bad_json_fails(self): with pytest.raises(ValueError): operator.execute(None) + def test_template_fields(self): + operator = EventBridgePutRuleOperator( + task_id="events_put_rule_job", name=RULE_NAME, event_pattern=EVENT_PATTERN + ) + validate_template_fields(operator) + class TestEventBridgeEnableRuleOperator: def test_init(self): @@ -186,6 +200,13 @@ def test_enable_rule(self, mock_conn: MagicMock): enable_rule.execute(context={}) mock_conn.enable_rule.assert_called_with(Name=RULE_NAME) + def test_template_fields(self): + operator = EventBridgeEnableRuleOperator( + task_id="events_enable_rule_job", + name=RULE_NAME, + ) + validate_template_fields(operator) + class TestEventBridgeDisableRuleOperator: def test_init(self): @@ -221,3 +242,10 @@ def test_disable_rule(self, mock_conn: MagicMock): disable_rule.execute(context={}) mock_conn.disable_rule.assert_called_with(Name=RULE_NAME) + + def test_template_fields(self): + operator = EventBridgeDisableRuleOperator( + task_id="events_disable_rule_job", + name=RULE_NAME, + ) + validate_template_fields(operator) diff --git a/tests/providers/amazon/aws/operators/test_glacier.py b/tests/providers/amazon/aws/operators/test_glacier.py index 4dbd8f2f5a42b..f46b0bc929fe8 100644 --- a/tests/providers/amazon/aws/operators/test_glacier.py +++ b/tests/providers/amazon/aws/operators/test_glacier.py @@ -26,6 +26,7 @@ GlacierCreateJobOperator, GlacierUploadArchiveOperator, ) +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields if TYPE_CHECKING: from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator @@ -78,6 +79,10 @@ def test_execute(self, hook_mock): op.execute(mock.MagicMock()) hook_mock.return_value.retrieve_inventory.assert_called_once_with(vault_name=VAULT_NAME) + def test_template_fields(self): + operator = self.op_class(**self.default_op_kwargs) + validate_template_fields(operator) + class TestGlacierUploadArchiveOperator(BaseGlacierOperatorsTests): op_class = GlacierUploadArchiveOperator @@ -97,3 +102,7 @@ def test_execute(self): body=b"Test Data", checksum=None, ) + + def test_template_fields(self): + operator = self.op_class(**self.default_op_kwargs) + validate_template_fields(operator) diff --git a/tests/providers/amazon/aws/operators/test_glue.py b/tests/providers/amazon/aws/operators/test_glue.py index 6c33dadb898cf..e1adcee7d639f 100644 --- a/tests/providers/amazon/aws/operators/test_glue.py +++ b/tests/providers/amazon/aws/operators/test_glue.py @@ -34,6 +34,7 @@ GlueDataQualityRuleSetEvaluationRunOperator, GlueJobOperator, ) +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields if TYPE_CHECKING: from airflow.models import TaskInstance @@ -307,6 +308,17 @@ def test_replace_script_file( "folder/file", "artifacts/glue-scripts/file", bucket_name="bucket_name", replace=True ) + def test_template_fields(self): + operator = GlueJobOperator( + task_id=TASK_ID, + job_name=JOB_NAME, + script_location="folder/file", + s3_bucket="bucket_name", + iam_role_name="role_arn", + replace_script_file=True, + ) + validate_template_fields(operator) + class TestGlueDataQualityOperator: RULE_SET_NAME = "TestRuleSet" @@ -435,6 +447,12 @@ def test_validate_inputs_error(self): with pytest.raises(AttributeError, match="RuleSet must starts with Rules = \\[ and ends with \\]"): self.operator.validate_inputs() + def test_template_fields(self): + operator = GlueDataQualityOperator( + task_id="create_data_quality_ruleset", name=self.RULE_SET_NAME, ruleset=self.RULE_SET + ) + validate_template_fields(operator) + class TestGlueDataQualityRuleSetEvaluationRunOperator: RUN_ID = "1234567890" @@ -538,6 +556,9 @@ def test_start_data_quality_ruleset_evaluation_run_wait_combinations( assert glue_data_quality_hook.get_waiter.call_count == wait_for_completion assert self.operator.defer.call_count == deferrable + def test_template_fields(self): + validate_template_fields(self.operator) + class TestGlueDataQualityRuleRecommendationRunOperator: RUN_ID = "1234567890" @@ -643,3 +664,6 @@ def test_start_data_quality_rule_recommendation_run_wait_combinations( assert response == self.RUN_ID assert glue_data_quality_hook.get_waiter.call_count == wait_for_completion assert self.operator.defer.call_count == deferrable + + def test_template_fields(self): + validate_template_fields(self.operator) diff --git a/tests/providers/amazon/aws/operators/test_glue_crawler.py b/tests/providers/amazon/aws/operators/test_glue_crawler.py index a7a38e78bb4a7..1e5a3f2177b5b 100644 --- a/tests/providers/amazon/aws/operators/test_glue_crawler.py +++ b/tests/providers/amazon/aws/operators/test_glue_crawler.py @@ -25,6 +25,7 @@ from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook from airflow.providers.amazon.aws.hooks.sts import StsHook from airflow.providers.amazon.aws.operators.glue_crawler import GlueCrawlerOperator +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields if TYPE_CHECKING: from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection @@ -173,3 +174,6 @@ def test_crawler_wait_combinations(self, _, wait_for_completion, deferrable, moc assert response == mock_crawler_name assert crawler_hook.get_waiter.call_count == wait_for_completion assert self.op.defer.call_count == deferrable + + def test_template_fields(self): + validate_template_fields(self.op) diff --git a/tests/providers/amazon/aws/operators/test_glue_databrew.py b/tests/providers/amazon/aws/operators/test_glue_databrew.py index 0e88d477f3380..a18c6ddd4a41a 100644 --- a/tests/providers/amazon/aws/operators/test_glue_databrew.py +++ b/tests/providers/amazon/aws/operators/test_glue_databrew.py @@ -26,6 +26,7 @@ from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.amazon.aws.hooks.glue_databrew import GlueDataBrewHook from airflow.providers.amazon.aws.operators.glue_databrew import GlueDataBrewStartJobOperator +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields JOB_NAME = "test_job" @@ -101,3 +102,7 @@ def test_start_job_with_deprecation_parameters(self, mock_hook_get_waiter, mock_ assert operator.waiter_delay == 15 operator.execute(None) mock_hook_get_waiter.assert_not_called() + + def test_template_fields(self): + operator = GlueDataBrewStartJobOperator(task_id="fake_task_id", job_name=JOB_NAME) + validate_template_fields(operator) diff --git a/tests/providers/amazon/aws/operators/test_kinesis_analytics.py b/tests/providers/amazon/aws/operators/test_kinesis_analytics.py index 36734d4f56b22..ab8bb3123008d 100644 --- a/tests/providers/amazon/aws/operators/test_kinesis_analytics.py +++ b/tests/providers/amazon/aws/operators/test_kinesis_analytics.py @@ -30,6 +30,7 @@ KinesisAnalyticsV2StartApplicationOperator, KinesisAnalyticsV2StopApplicationOperator, ) +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields if TYPE_CHECKING: from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection @@ -159,6 +160,15 @@ def test_create_application_throw_error_when_invalid_arguments_provided( with pytest.raises(AirflowException, match=error_message): operator.execute({}) + def test_template_fields(self): + operator = KinesisAnalyticsV2CreateApplicationOperator( + task_id="create_application_operator", + application_name="demo", + runtime_environment="FLINK_18_9", + service_execution_role="arn", + ) + validate_template_fields(operator) + class TestKinesisAnalyticsV2StartApplicationOperator: APPLICATION_ARN = "arn:aws:kinesisanalytics:us-east-1:123456789012:application/demo" @@ -327,6 +337,9 @@ def test_execute_complete_failure(self, kinesis_analytics_mock_conn): ): self.operator.execute_complete(context=None, event=event) + def test_template_fields(self): + validate_template_fields(self.operator) + class TestKinesisAnalyticsV2StopApplicationOperator: APPLICATION_ARN = "arn:aws:kinesisanalytics:us-east-1:123456789012:application/demo" @@ -483,3 +496,6 @@ def test_execute_complete_failure(self, kinesis_analytics_mock_conn): AirflowException, match="Error while stopping AWS Managed Service for Apache Flink application" ): self.operator.execute_complete(context=None, event=event) + + def test_template_fields(self): + validate_template_fields(self.operator) diff --git a/tests/providers/amazon/aws/operators/test_lambda_function.py b/tests/providers/amazon/aws/operators/test_lambda_function.py index 9c977ee883ed7..e3a5b8cad6201 100644 --- a/tests/providers/amazon/aws/operators/test_lambda_function.py +++ b/tests/providers/amazon/aws/operators/test_lambda_function.py @@ -29,6 +29,7 @@ LambdaCreateFunctionOperator, LambdaInvokeFunctionOperator, ) +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields FUNCTION_NAME = "function_name" PAYLOADS = [ @@ -160,6 +161,17 @@ def test_create_lambda_using_config_argument(self, mock_hook_conn, mock_hook_cre assert operator.config.get("snap_start") == config.get("snap_start") assert operator.config.get("ephemeral_storage") == config.get("ephemeral_storage") + def test_template_fields(self): + operator = LambdaCreateFunctionOperator( + task_id="task_test", + function_name=FUNCTION_NAME, + role=ROLE_ARN, + code={ + "ImageUri": IMAGE_URI, + }, + ) + validate_template_fields(operator) + class TestLambdaInvokeFunctionOperator: @pytest.mark.parametrize("payload", PAYLOADS) @@ -280,3 +292,10 @@ def test_invoke_lambda_function_error(self, hook_mock): with pytest.raises(ValueError): operator.execute(None) + + def test_template_fields(self): + operator = LambdaInvokeFunctionOperator( + task_id="task_test", + function_name="a", + ) + validate_template_fields(operator) diff --git a/tests/providers/amazon/aws/operators/test_neptune.py b/tests/providers/amazon/aws/operators/test_neptune.py index ffd6da04adbda..146effaaa20f7 100644 --- a/tests/providers/amazon/aws/operators/test_neptune.py +++ b/tests/providers/amazon/aws/operators/test_neptune.py @@ -30,6 +30,7 @@ NeptuneStartDbClusterOperator, NeptuneStopDbClusterOperator, ) +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields CLUSTER_ID = "test_cluster" @@ -201,6 +202,16 @@ def test_start_cluster_instances_not_ready_defer(self, mock_conn, mock_defer): # mock_defer.assert_has_calls(calls) assert mock_defer.call_count == 2 + def test_template_fields(self): + operator = NeptuneStartDbClusterOperator( + task_id="task_test", + db_cluster_id=CLUSTER_ID, + deferrable=True, + wait_for_completion=False, + aws_conn_id="aws_default", + ) + validate_template_fields(operator) + class TestNeptuneStopClusterOperator: @mock.patch.object(NeptuneHook, "conn") @@ -368,3 +379,13 @@ def test_stop_cluster_deferrable(self, mock_conn): with pytest.raises(TaskDeferred): operator.execute(None) + + def test_template_fields(self): + operator = NeptuneStopDbClusterOperator( + task_id="task_test", + db_cluster_id=CLUSTER_ID, + deferrable=True, + wait_for_completion=False, + aws_conn_id="aws_default", + ) + validate_template_fields(operator) diff --git a/tests/providers/amazon/aws/operators/test_quicksight.py b/tests/providers/amazon/aws/operators/test_quicksight.py index fd304262939ac..f2d23c7b81793 100644 --- a/tests/providers/amazon/aws/operators/test_quicksight.py +++ b/tests/providers/amazon/aws/operators/test_quicksight.py @@ -21,6 +21,7 @@ from airflow.providers.amazon.aws.hooks.quicksight import QuickSightHook from airflow.providers.amazon.aws.operators.quicksight import QuickSightCreateIngestionOperator +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields DATA_SET_ID = "DemoDataSet" INGESTION_ID = "DemoDataSet_Ingestion" @@ -80,3 +81,7 @@ def test_execute(self, mock_create_ingestion): wait_for_completion=True, check_interval=30, ) + + def test_template_fields(self): + operator = QuickSightCreateIngestionOperator(**self.default_op_kwargs) + validate_template_fields(operator) diff --git a/tests/providers/amazon/aws/operators/test_rds.py b/tests/providers/amazon/aws/operators/test_rds.py index b8eeb09963a96..0574d4b553b48 100644 --- a/tests/providers/amazon/aws/operators/test_rds.py +++ b/tests/providers/amazon/aws/operators/test_rds.py @@ -44,6 +44,7 @@ ) from airflow.providers.amazon.aws.triggers.rds import RdsDbAvailableTrigger, RdsDbStoppedTrigger from airflow.utils import timezone +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields if TYPE_CHECKING: from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook @@ -265,6 +266,16 @@ def test_create_db_cluster_snapshot_no_wait(self, mock_wait): assert len(cluster_snapshots) == 1 mock_wait.assert_not_called() + def test_template_fields(self): + operator = RdsCreateDbSnapshotOperator( + task_id="test_instance_", + db_type="instance", + db_snapshot_identifier=DB_INSTANCE_SNAPSHOT, + db_identifier=DB_INSTANCE_NAME, + aws_conn_id=AWS_CONN, + ) + validate_template_fields(operator) + class TestRdsCopyDbSnapshotOperator: @classmethod @@ -375,6 +386,16 @@ def test_copy_db_cluster_snapshot_no_wait(self, mock_await_status): assert len(cluster_snapshots) == 1 mock_await_status.assert_not_called() + def test_template_fields(self): + operator = RdsCopyDbSnapshotOperator( + task_id="test_cluster_no_wait", + db_type="cluster", + source_db_snapshot_identifier=DB_CLUSTER_SNAPSHOT, + target_db_snapshot_identifier=DB_CLUSTER_SNAPSHOT_COPY, + aws_conn_id=AWS_CONN, + ) + validate_template_fields(operator) + class TestRdsDeleteDbSnapshotOperator: @classmethod @@ -482,6 +503,16 @@ def test_delete_db_cluster_snapshot_no_wait(self): with pytest.raises(self.hook.conn.exceptions.ClientError): self.hook.conn.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=DB_CLUSTER_SNAPSHOT) + def test_template_fields(self): + operator = RdsDeleteDbSnapshotOperator( + task_id="test_delete_db_cluster_snapshot_no_wait", + db_type="cluster", + db_snapshot_identifier=DB_CLUSTER_SNAPSHOT, + aws_conn_id=AWS_CONN, + wait_for_completion=False, + ) + validate_template_fields(operator) + class TestRdsStartExportTaskOperator: @classmethod @@ -552,6 +583,19 @@ def test_start_export_task_no_wait(self, mock_await_status): assert export_tasks[0]["Status"] == "complete" mock_await_status.assert_not_called() + def test_template_fields(self): + operator = RdsStartExportTaskOperator( + task_id="test_start_no_wait", + export_task_identifier=EXPORT_TASK_NAME, + source_arn=EXPORT_TASK_SOURCE, + iam_role_arn=EXPORT_TASK_ROLE_ARN, + kms_key_id=EXPORT_TASK_KMS, + s3_bucket_name=EXPORT_TASK_BUCKET, + aws_conn_id=AWS_CONN, + wait_for_completion=False, + ) + validate_template_fields(operator) + class TestRdsCancelExportTaskOperator: @classmethod @@ -616,6 +660,14 @@ def test_cancel_export_task_no_wait(self, mock_await_status): assert export_tasks[0]["Status"] == "canceled" mock_await_status.assert_not_called() + def test_template_fields(self): + operator = RdsCancelExportTaskOperator( + task_id="test_cancel", + export_task_identifier=EXPORT_TASK_NAME, + aws_conn_id=AWS_CONN, + ) + validate_template_fields(operator) + class TestRdsCreateEventSubscriptionOperator: @classmethod @@ -682,6 +734,17 @@ def test_create_event_subscription_no_wait(self, mock_await_status): assert subscriptions[0]["Status"] == "active" mock_await_status.assert_not_called() + def test_template_fields(self): + operator = RdsCreateEventSubscriptionOperator( + task_id="test_create", + subscription_name=SUBSCRIPTION_NAME, + sns_topic_arn=SUBSCRIPTION_TOPIC, + source_type="db-instance", + source_ids=[DB_INSTANCE_NAME], + aws_conn_id=AWS_CONN, + ) + validate_template_fields(operator) + class TestRdsDeleteEventSubscriptionOperator: @classmethod @@ -715,6 +778,14 @@ def test_delete_event_subscription(self): with pytest.raises(self.hook.conn.exceptions.ClientError): self.hook.conn.describe_event_subscriptions(SubscriptionName=EXPORT_TASK_NAME) + def test_template_fields(self): + operator = RdsDeleteEventSubscriptionOperator( + task_id="test_delete", + subscription_name=SUBSCRIPTION_NAME, + aws_conn_id=AWS_CONN, + ) + validate_template_fields(operator) + class TestRdsCreateDbInstanceOperator: @classmethod @@ -781,6 +852,19 @@ def test_create_db_instance_no_wait(self, mock_await_status): assert db_instances[0]["DBInstanceStatus"] == "available" mock_await_status.assert_not_called() + def test_template_fields(self): + operator = RdsCreateDbInstanceOperator( + task_id="test_create_db_instance", + db_instance_identifier=DB_INSTANCE_NAME, + db_instance_class="db.m5.large", + engine="postgres", + rds_kwargs={ + "DBName": DB_INSTANCE_NAME, + }, + aws_conn_id=AWS_CONN, + ) + validate_template_fields(operator) + class TestRdsDeleteDbInstanceOperator: @classmethod @@ -839,6 +923,18 @@ def test_delete_db_instance_no_wait(self, mock_await_status): self.hook.conn.describe_db_instances(DBInstanceIdentifier=DB_INSTANCE_NAME) mock_await_status.assert_not_called() + def test_template_fields(self): + operator = RdsDeleteDbInstanceOperator( + task_id="test_delete_db_instance_no_wait", + db_instance_identifier=DB_INSTANCE_NAME, + rds_kwargs={ + "SkipFinalSnapshot": True, + }, + aws_conn_id=AWS_CONN, + wait_for_completion=False, + ) + validate_template_fields(operator) + class TestRdsStopDbOperator: @classmethod @@ -943,6 +1039,15 @@ def test_stop_db_cluster_create_snapshot_logs_warning_message(self, caplog): ) assert warning_message in caplog.text + def test_template_fields(self): + operator = RdsStopDbOperator( + task_id="test_stop_db_cluster", + db_identifier=DB_CLUSTER_NAME, + db_type="cluster", + db_snapshot_identifier=DB_CLUSTER_SNAPSHOT, + ) + validate_template_fields(operator) + class TestRdsStartDbOperator: @classmethod @@ -1008,3 +1113,9 @@ def test_deferred(self, conn_mock): op.execute({}) assert isinstance(defer.value.trigger, RdsDbAvailableTrigger) + + def test_template_fields(self): + operator = RdsStartDbOperator( + task_id="test_start_db_cluster", db_identifier=DB_CLUSTER_NAME, db_type="cluster" + ) + validate_template_fields(operator) diff --git a/tests/providers/amazon/aws/operators/test_redshift_cluster.py b/tests/providers/amazon/aws/operators/test_redshift_cluster.py index f6c960b7bd133..e48f7b2ed96ea 100644 --- a/tests/providers/amazon/aws/operators/test_redshift_cluster.py +++ b/tests/providers/amazon/aws/operators/test_redshift_cluster.py @@ -38,6 +38,7 @@ RedshiftPauseClusterTrigger, RedshiftResumeClusterTrigger, ) +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields class TestRedshiftCreateClusterOperator: @@ -137,6 +138,18 @@ def test_create_cluster_deferrable(self, mock_get_conn): with pytest.raises(TaskDeferred): redshift_operator.execute(None) + def test_template_fields(self): + operator = RedshiftCreateClusterOperator( + task_id="task_test", + cluster_identifier="test-cluster", + node_type="dc2.large", + master_username="adminuser", + master_user_password="Test123$", + cluster_type="single-node", + deferrable=True, + ) + validate_template_fields(operator) + class TestRedshiftCreateClusterSnapshotOperator: @mock.patch.object(RedshiftHook, "cluster_status") @@ -214,6 +227,15 @@ def test_create_cluster_snapshot_deferred(self, mock_create_cluster_snapshot, mo exc.value.trigger, RedshiftCreateClusterSnapshotTrigger ), "Trigger is not a RedshiftCreateClusterSnapshotTrigger" + def test_template_fields(self): + operator = RedshiftCreateClusterSnapshotOperator( + task_id="test_snapshot", + cluster_identifier="test_cluster", + snapshot_identifier="test_snapshot", + wait_for_completion=True, + ) + validate_template_fields(operator) + class TestRedshiftDeleteClusterSnapshotOperator: @mock.patch( @@ -256,6 +278,15 @@ def test_delete_cluster_snapshot(self, mock_get_conn, mock_get_cluster_snapshot_ mock_get_cluster_snapshot_status.assert_not_called() + def test_template_fields(self): + operator = RedshiftDeleteClusterSnapshotOperator( + task_id="test_snapshot", + cluster_identifier="test_cluster", + snapshot_identifier="test_snapshot", + wait_for_completion=False, + ) + validate_template_fields(operator) + class TestResumeClusterOperator: def test_init(self): @@ -386,6 +417,14 @@ def test_resume_cluster_failure(self): context=None, event={"status": "error", "message": "test failure message"} ) + def test_template_fields(self): + operator = RedshiftResumeClusterOperator( + task_id="task_test", + cluster_identifier="test_cluster", + aws_conn_id="aws_conn_test", + ) + validate_template_fields(operator) + class TestPauseClusterOperator: def test_init(self): @@ -511,6 +550,13 @@ def test_pause_cluster_execute_complete_fail(self): context=None, event={"status": "error", "message": "test failure message"} ) + def test_template_fields(self): + operator = RedshiftPauseClusterOperator( + task_id="task_test", + cluster_identifier="test_cluster", + ) + validate_template_fields(operator) + class TestDeleteClusterOperator: @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status") @@ -648,3 +694,10 @@ def test_delete_cluster_execute_complete_fail(self): redshift_operator.execute_complete( context=None, event={"status": "error", "message": "test failure message"} ) + + def test_template_fields(self): + operator = RedshiftDeleteClusterOperator( + task_id="task_test", + cluster_identifier="test_cluster", + ) + validate_template_fields(operator) diff --git a/tests/providers/amazon/aws/operators/test_redshift_data.py b/tests/providers/amazon/aws/operators/test_redshift_data.py index fa021395a419d..abfa2b038b98b 100644 --- a/tests/providers/amazon/aws/operators/test_redshift_data.py +++ b/tests/providers/amazon/aws/operators/test_redshift_data.py @@ -24,6 +24,7 @@ from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, TaskDeferred from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator from airflow.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields CONN_ID = "aws_conn_test" TASK_ID = "task_id" @@ -349,3 +350,14 @@ def test_no_wait_for_completion(self, mock_exec_query, mock_check_query_is_finis assert not mock_check_query_is_finished.called assert not mock_defer.called + + def test_template_fields(self): + operator = RedshiftDataOperator( + aws_conn_id=CONN_ID, + task_id=TASK_ID, + cluster_identifier="cluster_identifier", + sql=SQL, + database=DATABASE, + wait_for_completion=False, + ) + validate_template_fields(operator) diff --git a/tests/providers/amazon/aws/operators/test_s3.py b/tests/providers/amazon/aws/operators/test_s3.py index a7a38cabe81bd..267b678c8dd89 100644 --- a/tests/providers/amazon/aws/operators/test_s3.py +++ b/tests/providers/amazon/aws/operators/test_s3.py @@ -52,6 +52,7 @@ ) from airflow.providers.openlineage.extractors import OperatorLineage from airflow.utils.timezone import datetime, utcnow +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields BUCKET_NAME = os.environ.get("BUCKET_NAME", "test-airflow-bucket") S3_KEY = "test-airflow-key" @@ -85,6 +86,9 @@ def test_execute_if_not_bucket_exist(self, mock_check_for_bucket, mock_create_bu mock_check_for_bucket.assert_called_once_with(BUCKET_NAME) mock_create_bucket.assert_called_once_with(bucket_name=BUCKET_NAME, region_name=None) + def test_template_fields(self): + validate_template_fields(self.create_bucket_operator) + class TestS3DeleteBucketOperator: def setup_method(self): @@ -113,6 +117,9 @@ def test_execute_if_not_bucket_exist(self, mock_check_for_bucket, mock_delete_bu mock_check_for_bucket.assert_called_once_with(BUCKET_NAME) mock_delete_bucket.assert_not_called() + def test_template_fields(self): + validate_template_fields(self.delete_bucket_operator) + class TestS3GetBucketTaggingOperator: def setup_method(self): @@ -141,6 +148,9 @@ def test_execute_if_not_bucket_exist(self, mock_check_for_bucket, get_bucket_tag mock_check_for_bucket.assert_called_once_with(BUCKET_NAME) get_bucket_tagging.assert_not_called() + def test_template_fields(self): + validate_template_fields(self.get_bucket_tagging_operator) + class TestS3PutBucketTaggingOperator: def setup_method(self): @@ -172,6 +182,9 @@ def test_execute_if_not_bucket_exist(self, mock_check_for_bucket, put_bucket_tag mock_check_for_bucket.assert_called_once_with(BUCKET_NAME) put_bucket_tagging.assert_not_called() + def test_template_fields(self): + validate_template_fields(self.put_bucket_tagging_operator) + class TestS3DeleteBucketTaggingOperator: def setup_method(self): @@ -200,6 +213,9 @@ def test_execute_if_not_bucket_exist(self, mock_check_for_bucket, delete_bucket_ mock_check_for_bucket.assert_called_once_with(BUCKET_NAME) delete_bucket_tagging.assert_not_called() + def test_template_fields(self): + validate_template_fields(self.delete_bucket_tagging_operator) + class TestS3FileTransformOperator: def setup_method(self): @@ -381,6 +397,16 @@ def s3_paths(self): return input_path, output_path + def test_template_fields(self): + operator = S3FileTransformOperator( + source_s3_key="test/key", + dest_s3_key="test/key", + transform_script=self.transform_script, + replace=True, + task_id="task_id", + ) + validate_template_fields(operator) + class TestS3ListOperator: @mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook") @@ -404,6 +430,15 @@ def test_execute(self, mock_hook): ) assert sorted(files) == sorted(["TEST1.csv", "TEST2.csv", "TEST3.csv"]) + def test_template_fields(self): + operator = S3ListOperator( + task_id="test-s3-list-operator", + bucket=BUCKET_NAME, + prefix="TEST", + delimiter=".csv", + ) + validate_template_fields(operator) + class TestS3ListPrefixesOperator: @mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook") @@ -421,6 +456,12 @@ def test_execute(self, mock_hook): ) assert subfolders == ["test/"] + def test_template_fields(self): + operator = S3ListPrefixesOperator( + task_id="test-s3-list-prefixes-operator", bucket=BUCKET_NAME, prefix="test/", delimiter="/" + ) + validate_template_fields(operator) + class TestS3CopyObjectOperator: def setup_method(self): @@ -528,6 +569,16 @@ def test_get_openlineage_facets_on_start_combination_2(self): assert lineage.inputs[0] == expected_input assert lineage.outputs[0] == expected_output + def test_template_fields(self): + operator = S3CopyObjectOperator( + task_id="test_task_s3_copy_object", + source_bucket_key=self.source_key, + source_bucket_name=self.source_bucket, + dest_bucket_key=self.dest_key, + dest_bucket_name=self.dest_bucket, + ) + validate_template_fields(operator) + @mock_aws class TestS3DeleteObjectsOperator: @@ -839,6 +890,12 @@ def test_get_openlineage_facets_on_complete_no_objects(self, mock_hook, keys): lineage = op.get_openlineage_facets_on_complete(None) assert lineage == OperatorLineage() + def test_template_fields(self): + operator = S3DeleteObjectsOperator( + task_id="test_task_s3_delete_single_object", bucket="test-bucket", keys="test/file.csv" + ) + validate_template_fields(operator) + class TestS3CreateObjectOperator: @mock.patch.object(S3Hook, "load_string") @@ -892,3 +949,7 @@ def test_get_openlineage_facets_on_start(self, bucket, key): assert len(lineage.inputs) == 0 assert len(lineage.outputs) == 1 assert lineage.outputs[0] == expected_output + + def test_template_fields(self): + operator = S3CreateObjectOperator(task_id="test", s3_bucket="bucket", s3_key="key", data="test") + validate_template_fields(operator) diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_base.py b/tests/providers/amazon/aws/operators/test_sagemaker_base.py index 25e7ff9c2a443..5de40708d5158 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_base.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_base.py @@ -32,6 +32,7 @@ ) from airflow.utils import timezone from airflow.utils.types import DagRunType +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields CONFIG: dict = { "key1": "1", @@ -195,3 +196,11 @@ def test_create_experiment(self, conn_mock, session, clean_dags_and_dagruns): Description="the desc", Tags=[{"Key": "jinja", "Value": "tid"}], ) + + def test_template_fields(self): + operator = SageMakerCreateExperimentOperator( + name="the name", + description="the desc", + task_id="tid", + ) + validate_template_fields(operator) diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py index d31556f9bf9a8..24cf944f8db6a 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py @@ -27,6 +27,7 @@ from airflow.providers.amazon.aws.operators import sagemaker from airflow.providers.amazon.aws.operators.sagemaker import SageMakerEndpointOperator from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields CREATE_MODEL_PARAMS: dict = { "ModelName": "model_name", @@ -173,3 +174,6 @@ def test_deferred(self, mock_create_endpoint, _, __): assert isinstance(defer.value.trigger, SageMakerTrigger) assert defer.value.trigger.job_name == "endpoint_name" assert defer.value.trigger.job_type == "endpoint" + + def test_template_fields(self): + validate_template_fields(self.sagemaker) diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py index 29a488140f718..1169f09d9141d 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py @@ -25,6 +25,7 @@ from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook from airflow.providers.amazon.aws.operators import sagemaker from airflow.providers.amazon.aws.operators.sagemaker import SageMakerEndpointConfigOperator +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields CREATE_ENDPOINT_CONFIG_PARAMS: dict = { "EndpointConfigName": "config_name", @@ -81,3 +82,6 @@ def test_execute_with_failure(self, mock_model, mock_client): } with pytest.raises(AirflowException): self.sagemaker.execute(None) + + def test_template_fields(self): + validate_template_fields(self.sagemaker) diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_model.py b/tests/providers/amazon/aws/operators/test_sagemaker_model.py index c61585d5b8ca0..33d1f5b4d1f6f 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_model.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_model.py @@ -31,6 +31,7 @@ SageMakerModelOperator, SageMakerRegisterModelVersionOperator, ) +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields CREATE_MODEL_PARAMS: dict = { "ModelName": "model_name", @@ -74,6 +75,12 @@ def test_execute(self, delete_model): op.execute(None) delete_model.assert_called_once_with(model_name="model_name") + def test_template_fields(self): + op = SageMakerDeleteModelOperator( + task_id="test_sagemaker_operator", config={"ModelName": "model_name"} + ) + validate_template_fields(op) + class TestSageMakerRegisterModelVersionOperator: @patch.object(SageMakerHook, "create_model_package_group") @@ -144,3 +151,14 @@ def test_can_override_parameters_using_extras(self, conn_mock, _): conn_mock().create_model_package.assert_called_once() args_dict = conn_mock().create_model_package.call_args.kwargs assert args_dict["InferenceSpecification"]["SupportedResponseMIMETypes"] == response_type + + def test_template_fields(self): + response_type = ["test/test"] + op = SageMakerRegisterModelVersionOperator( + task_id="test", + image_uri="257758044811.dkr.ecr.us-east-2.amazonaws.com/sagemaker-xgboost:1.2-1", + model_url="s3://your-bucket-name/model.tar.gz", + package_group_name="group-name", + extras={"InferenceSpecification": {"SupportedResponseMIMETypes": response_type}}, + ) + validate_template_fields(op) diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_notebook.py b/tests/providers/amazon/aws/operators/test_sagemaker_notebook.py index cde4944440c4c..093e264a3e454 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_notebook.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_notebook.py @@ -30,6 +30,7 @@ SageMakerStartNoteBookOperator, SageMakerStopNotebookOperator, ) +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields INSTANCE_NAME = "notebook" INSTANCE_TYPE = "ml.t3.medium" @@ -105,6 +106,15 @@ def test_create_notebook_wait_for_completion(self, mock_hook_conn): mock_hook_conn.create_notebook_instance.assert_called_once() mock_hook_conn.get_waiter.assert_called_once_with("notebook_instance_in_service") + def test_template_fields(self): + operator = SageMakerCreateNotebookOperator( + task_id="task_test", + instance_name=INSTANCE_NAME, + instance_type=INSTANCE_TYPE, + role_arn=ROLE_ARN, + ) + validate_template_fields(operator) + class TestSageMakerStopNotebookOperator: @mock.patch.object(SageMakerHook, "conn") @@ -125,6 +135,12 @@ def test_stop_notebook_wait_for_completion(self, mock_hook_conn, hook): hook.conn.stop_notebook_instance.assert_called_once() mock_hook_conn.get_waiter.assert_called_once_with("notebook_instance_stopped") + def test_template_fields(self): + operator = SageMakerStopNotebookOperator( + task_id="stop_test", instance_name=INSTANCE_NAME, wait_for_completion=False + ) + validate_template_fields(operator) + class TestSageMakerDeleteNotebookOperator: @mock.patch.object(SageMakerHook, "conn") @@ -145,6 +161,12 @@ def test_delete_notebook_wait_for_completion(self, mock_hook_conn, hook): hook.conn.delete_notebook_instance.assert_called_once() mock_hook_conn.get_waiter.assert_called_once_with("notebook_instance_deleted") + def test_template_fields(self): + operator = SageMakerDeleteNotebookOperator( + task_id="delete_test", instance_name=INSTANCE_NAME, wait_for_completion=True + ) + validate_template_fields(operator) + class TestSageMakerStartNotebookOperator: @mock.patch.object(SageMakerHook, "conn") @@ -164,3 +186,9 @@ def test_start_notebook_wait_for_completion(self, mock_hook_conn, hook): operator.execute(None) hook.conn.start_notebook_instance.assert_called_once() mock_hook_conn.get_waiter.assert_called_once_with("notebook_instance_in_service") + + def test_template_fields(self): + operator = SageMakerStartNoteBookOperator( + task_id="start_test", instance_name=INSTANCE_NAME, wait_for_completion=True + ) + validate_template_fields(operator) diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_pipeline.py b/tests/providers/amazon/aws/operators/test_sagemaker_pipeline.py index 7e7faa5902631..e7334de98df03 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_pipeline.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_pipeline.py @@ -29,6 +29,7 @@ SageMakerStopPipelineOperator, ) from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerPipelineTrigger +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields if TYPE_CHECKING: from unittest.mock import MagicMock @@ -71,6 +72,13 @@ def test_defer(self, start_mock): assert isinstance(defer.value.trigger, SageMakerPipelineTrigger) assert defer.value.trigger.waiter_type == SageMakerPipelineTrigger.Type.COMPLETE + def test_template_fields(self): + operator = SageMakerStartPipelineOperator( + task_id="test_sagemaker_operator", + pipeline_name="my_pipeline", + ) + validate_template_fields(operator) + class TestSageMakerStopPipelineOperator: @mock.patch.object(SageMakerHook, "stop_pipeline") @@ -100,3 +108,11 @@ def test_defer(self, stop_mock: MagicMock): assert isinstance(defer.value.trigger, SageMakerPipelineTrigger) assert defer.value.trigger.waiter_type == SageMakerPipelineTrigger.Type.STOPPED + + def test_template_fields(self): + operator = SageMakerStopPipelineOperator( + task_id="test_sagemaker_operator", + pipeline_exec_arn="my_pipeline_arn", + deferrable=True, + ) + validate_template_fields(operator) diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py index 45e90e204a5ef..b1ca2b62adbb9 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py @@ -31,6 +31,7 @@ from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger from airflow.providers.common.compat.openlineage.facet import Dataset from airflow.providers.openlineage.extractors import OperatorLineage +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields CREATE_PROCESSING_PARAMS: dict = { "AppSpecification": { @@ -345,3 +346,10 @@ def test_operator_openlineage_data(self, check_job_exists, mock_processing, _, m inputs=[Dataset(namespace="s3://input-bucket", name="input-path")], outputs=[Dataset(namespace="s3://output-bucket", name="output-path")], ) + + def test_template_fields(self): + operator = SageMakerProcessingOperator( + **self.processing_config_kwargs, + config=CREATE_PROCESSING_PARAMS, + ) + validate_template_fields(operator) diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_training.py b/tests/providers/amazon/aws/operators/test_sagemaker_training.py index 4426b4f152371..85c6954ac1ac5 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_training.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_training.py @@ -31,6 +31,7 @@ ) from airflow.providers.common.compat.openlineage.facet import Dataset from airflow.providers.openlineage.extractors import OperatorLineage +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields EXPECTED_INTEGER_FIELDS: list[list[str]] = [ ["ResourceConfig", "InstanceCount"], @@ -235,3 +236,6 @@ def test_execute_openlineage_data(self, mock_exists, mock_training, mock_desc): inputs=[Dataset(namespace="s3://input-bucket", name="input-path")], outputs=[Dataset(namespace="s3://model-bucket", name="model-path")], ) + + def test_template_fields(self): + validate_template_fields(self.sagemaker) diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py index 314d0ba46a525..7804ddde20400 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py @@ -30,6 +30,7 @@ from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger from airflow.providers.common.compat.openlineage.facet import Dataset from airflow.providers.openlineage.extractors import OperatorLineage +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields EXPECTED_INTEGER_FIELDS: list[list[str]] = [ ["Transform", "TransformResources", "InstanceCount"], @@ -377,3 +378,6 @@ def test_operator_lineage_data(self, mock_transform, mock_conn, mock_model, _, m ], outputs=[Dataset(namespace="s3://output-bucket", name="output-path")], ) + + def test_template_fields(self): + validate_template_fields(self.sagemaker) diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py b/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py index 964decd524d2e..78058c771a28d 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py @@ -26,6 +26,7 @@ from airflow.providers.amazon.aws.operators import sagemaker from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTuningOperator from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields EXPECTED_INTEGER_FIELDS: list[list[str]] = [ ["HyperParameterTuningJobConfig", "ResourceLimits", "MaxNumberOfTrainingJobs"], @@ -120,3 +121,6 @@ def test_defers(self, create_mock): assert isinstance(defer.value.trigger, SageMakerTrigger) assert defer.value.trigger.job_name == "job_name" assert defer.value.trigger.job_type == "tuning" + + def test_template_fields(self): + validate_template_fields(self.sagemaker) diff --git a/tests/providers/amazon/aws/operators/test_sns.py b/tests/providers/amazon/aws/operators/test_sns.py index 780bc7eade3ea..6c5de06822d3a 100644 --- a/tests/providers/amazon/aws/operators/test_sns.py +++ b/tests/providers/amazon/aws/operators/test_sns.py @@ -22,6 +22,7 @@ import pytest from airflow.providers.amazon.aws.operators.sns import SnsPublishOperator +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields TASK_ID = "sns_publish_job" AWS_CONN_ID = "custom_aws_conn" @@ -76,3 +77,7 @@ def test_execute(self, mocked_hook): subject=SUBJECT, target_arn=TARGET_ARN, ) + + def test_template_fields(self): + operator = SnsPublishOperator(**self.default_op_kwargs) + validate_template_fields(operator) diff --git a/tests/providers/amazon/aws/operators/test_sqs.py b/tests/providers/amazon/aws/operators/test_sqs.py index 1f313daeda23b..2187262fe7c36 100644 --- a/tests/providers/amazon/aws/operators/test_sqs.py +++ b/tests/providers/amazon/aws/operators/test_sqs.py @@ -25,6 +25,7 @@ from airflow.providers.amazon.aws.hooks.sqs import SqsHook from airflow.providers.amazon.aws.operators.sqs import SqsPublishOperator +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields REGION_NAME = "eu-west-1" QUEUE_NAME = "test-queue" @@ -119,3 +120,9 @@ def test_execute_success_fifo_queue(self, mocked_context): assert message["Messages"][0]["MessageId"] == result["MessageId"] assert message["Messages"][0]["Body"] == "hello" assert message["Messages"][0]["Attributes"]["MessageGroupId"] == "abc" + + def test_template_fields(self): + operator = SqsPublishOperator( + **self.default_op_kwargs, sqs_queue=FIFO_QUEUE_NAME, message_group_id="abc" + ) + validate_template_fields(operator) diff --git a/tests/providers/amazon/aws/operators/test_step_function.py b/tests/providers/amazon/aws/operators/test_step_function.py index e8ab5c85a6b8f..29d743996af48 100644 --- a/tests/providers/amazon/aws/operators/test_step_function.py +++ b/tests/providers/amazon/aws/operators/test_step_function.py @@ -26,6 +26,7 @@ StepFunctionGetExecutionOutputOperator, StepFunctionStartExecutionOperator, ) +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields EXECUTION_ARN = ( "arn:aws:states:us-east-1:123456789012:execution:" @@ -102,6 +103,14 @@ def test_execute(self, mocked_hook, mocked_context, response, expected_output): execution_arn=EXECUTION_ARN, ) + def test_template_fields(self): + operator = StepFunctionGetExecutionOutputOperator( + task_id=self.TASK_ID, + execution_arn=EXECUTION_ARN, + aws_conn_id=None, + ) + validate_template_fields(operator) + class TestStepFunctionStartExecutionOperator: TASK_ID = "step_function_start_execution_task" @@ -232,3 +241,14 @@ def test_start_redrive_execution(self, mocked_hook, mocked_context): region_name=mock.ANY, execution_arn=EXECUTION_ARN, ) + + def test_template_fields(self): + operator = StepFunctionStartExecutionOperator( + task_id=self.TASK_ID, + state_machine_arn=STATE_MACHINE_ARN, + name=NAME, + is_redrive_execution=True, + state_machine_input=None, + aws_conn_id=None, + ) + validate_template_fields(operator)