Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions tests/providers/amazon/aws/operators/test_eventbridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
EventBridgePutEventsOperator,
EventBridgePutRuleOperator,
)
from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields

if TYPE_CHECKING:
from unittest.mock import MagicMock
Expand Down Expand Up @@ -96,6 +97,13 @@ def test_failed_to_send(self, mock_conn: MagicMock):
with pytest.raises(AirflowException):
operator.execute(context={})

def test_template_fields(self):
operator = EventBridgePutEventsOperator(
task_id="failed_put_events_job",
entries=ENTRIES,
)
validate_template_fields(operator)


class TestEventBridgePutRuleOperator:
def test_init(self):
Expand Down Expand Up @@ -150,6 +158,12 @@ def test_put_rule_with_bad_json_fails(self):
with pytest.raises(ValueError):
operator.execute(None)

def test_template_fields(self):
operator = EventBridgePutRuleOperator(
task_id="events_put_rule_job", name=RULE_NAME, event_pattern=EVENT_PATTERN
)
validate_template_fields(operator)


class TestEventBridgeEnableRuleOperator:
def test_init(self):
Expand Down Expand Up @@ -186,6 +200,13 @@ def test_enable_rule(self, mock_conn: MagicMock):
enable_rule.execute(context={})
mock_conn.enable_rule.assert_called_with(Name=RULE_NAME)

def test_template_fields(self):
operator = EventBridgeEnableRuleOperator(
task_id="events_enable_rule_job",
name=RULE_NAME,
)
validate_template_fields(operator)


class TestEventBridgeDisableRuleOperator:
def test_init(self):
Expand Down Expand Up @@ -221,3 +242,10 @@ def test_disable_rule(self, mock_conn: MagicMock):

disable_rule.execute(context={})
mock_conn.disable_rule.assert_called_with(Name=RULE_NAME)

def test_template_fields(self):
operator = EventBridgeDisableRuleOperator(
task_id="events_disable_rule_job",
name=RULE_NAME,
)
validate_template_fields(operator)
9 changes: 9 additions & 0 deletions tests/providers/amazon/aws/operators/test_glacier.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
GlacierCreateJobOperator,
GlacierUploadArchiveOperator,
)
from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields

if TYPE_CHECKING:
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
Expand Down Expand Up @@ -78,6 +79,10 @@ def test_execute(self, hook_mock):
op.execute(mock.MagicMock())
hook_mock.return_value.retrieve_inventory.assert_called_once_with(vault_name=VAULT_NAME)

def test_template_fields(self):
operator = self.op_class(**self.default_op_kwargs)
validate_template_fields(operator)


class TestGlacierUploadArchiveOperator(BaseGlacierOperatorsTests):
op_class = GlacierUploadArchiveOperator
Expand All @@ -97,3 +102,7 @@ def test_execute(self):
body=b"Test Data",
checksum=None,
)

def test_template_fields(self):
operator = self.op_class(**self.default_op_kwargs)
validate_template_fields(operator)
24 changes: 24 additions & 0 deletions tests/providers/amazon/aws/operators/test_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
GlueDataQualityRuleSetEvaluationRunOperator,
GlueJobOperator,
)
from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields

if TYPE_CHECKING:
from airflow.models import TaskInstance
Expand Down Expand Up @@ -307,6 +308,17 @@ def test_replace_script_file(
"folder/file", "artifacts/glue-scripts/file", bucket_name="bucket_name", replace=True
)

def test_template_fields(self):
operator = GlueJobOperator(
task_id=TASK_ID,
job_name=JOB_NAME,
script_location="folder/file",
s3_bucket="bucket_name",
iam_role_name="role_arn",
replace_script_file=True,
)
validate_template_fields(operator)


class TestGlueDataQualityOperator:
RULE_SET_NAME = "TestRuleSet"
Expand Down Expand Up @@ -435,6 +447,12 @@ def test_validate_inputs_error(self):
with pytest.raises(AttributeError, match="RuleSet must starts with Rules = \\[ and ends with \\]"):
self.operator.validate_inputs()

def test_template_fields(self):
operator = GlueDataQualityOperator(
task_id="create_data_quality_ruleset", name=self.RULE_SET_NAME, ruleset=self.RULE_SET
)
validate_template_fields(operator)


class TestGlueDataQualityRuleSetEvaluationRunOperator:
RUN_ID = "1234567890"
Expand Down Expand Up @@ -538,6 +556,9 @@ def test_start_data_quality_ruleset_evaluation_run_wait_combinations(
assert glue_data_quality_hook.get_waiter.call_count == wait_for_completion
assert self.operator.defer.call_count == deferrable

def test_template_fields(self):
validate_template_fields(self.operator)


class TestGlueDataQualityRuleRecommendationRunOperator:
RUN_ID = "1234567890"
Expand Down Expand Up @@ -643,3 +664,6 @@ def test_start_data_quality_rule_recommendation_run_wait_combinations(
assert response == self.RUN_ID
assert glue_data_quality_hook.get_waiter.call_count == wait_for_completion
assert self.operator.defer.call_count == deferrable

def test_template_fields(self):
validate_template_fields(self.operator)
4 changes: 4 additions & 0 deletions tests/providers/amazon/aws/operators/test_glue_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook
from airflow.providers.amazon.aws.hooks.sts import StsHook
from airflow.providers.amazon.aws.operators.glue_crawler import GlueCrawlerOperator
from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields

if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection
Expand Down Expand Up @@ -173,3 +174,6 @@ def test_crawler_wait_combinations(self, _, wait_for_completion, deferrable, moc
assert response == mock_crawler_name
assert crawler_hook.get_waiter.call_count == wait_for_completion
assert self.op.defer.call_count == deferrable

def test_template_fields(self):
validate_template_fields(self.op)
5 changes: 5 additions & 0 deletions tests/providers/amazon/aws/operators/test_glue_databrew.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.glue_databrew import GlueDataBrewHook
from airflow.providers.amazon.aws.operators.glue_databrew import GlueDataBrewStartJobOperator
from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields

JOB_NAME = "test_job"

Expand Down Expand Up @@ -101,3 +102,7 @@ def test_start_job_with_deprecation_parameters(self, mock_hook_get_waiter, mock_
assert operator.waiter_delay == 15
operator.execute(None)
mock_hook_get_waiter.assert_not_called()

def test_template_fields(self):
operator = GlueDataBrewStartJobOperator(task_id="fake_task_id", job_name=JOB_NAME)
validate_template_fields(operator)
16 changes: 16 additions & 0 deletions tests/providers/amazon/aws/operators/test_kinesis_analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
KinesisAnalyticsV2StartApplicationOperator,
KinesisAnalyticsV2StopApplicationOperator,
)
from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields

if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection
Expand Down Expand Up @@ -159,6 +160,15 @@ def test_create_application_throw_error_when_invalid_arguments_provided(
with pytest.raises(AirflowException, match=error_message):
operator.execute({})

def test_template_fields(self):
operator = KinesisAnalyticsV2CreateApplicationOperator(
task_id="create_application_operator",
application_name="demo",
runtime_environment="FLINK_18_9",
service_execution_role="arn",
)
validate_template_fields(operator)


class TestKinesisAnalyticsV2StartApplicationOperator:
APPLICATION_ARN = "arn:aws:kinesisanalytics:us-east-1:123456789012:application/demo"
Expand Down Expand Up @@ -327,6 +337,9 @@ def test_execute_complete_failure(self, kinesis_analytics_mock_conn):
):
self.operator.execute_complete(context=None, event=event)

def test_template_fields(self):
validate_template_fields(self.operator)


class TestKinesisAnalyticsV2StopApplicationOperator:
APPLICATION_ARN = "arn:aws:kinesisanalytics:us-east-1:123456789012:application/demo"
Expand Down Expand Up @@ -483,3 +496,6 @@ def test_execute_complete_failure(self, kinesis_analytics_mock_conn):
AirflowException, match="Error while stopping AWS Managed Service for Apache Flink application"
):
self.operator.execute_complete(context=None, event=event)

def test_template_fields(self):
validate_template_fields(self.operator)
19 changes: 19 additions & 0 deletions tests/providers/amazon/aws/operators/test_lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
LambdaCreateFunctionOperator,
LambdaInvokeFunctionOperator,
)
from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields

FUNCTION_NAME = "function_name"
PAYLOADS = [
Expand Down Expand Up @@ -160,6 +161,17 @@ def test_create_lambda_using_config_argument(self, mock_hook_conn, mock_hook_cre
assert operator.config.get("snap_start") == config.get("snap_start")
assert operator.config.get("ephemeral_storage") == config.get("ephemeral_storage")

def test_template_fields(self):
operator = LambdaCreateFunctionOperator(
task_id="task_test",
function_name=FUNCTION_NAME,
role=ROLE_ARN,
code={
"ImageUri": IMAGE_URI,
},
)
validate_template_fields(operator)


class TestLambdaInvokeFunctionOperator:
@pytest.mark.parametrize("payload", PAYLOADS)
Expand Down Expand Up @@ -280,3 +292,10 @@ def test_invoke_lambda_function_error(self, hook_mock):

with pytest.raises(ValueError):
operator.execute(None)

def test_template_fields(self):
operator = LambdaInvokeFunctionOperator(
task_id="task_test",
function_name="a",
)
validate_template_fields(operator)
21 changes: 21 additions & 0 deletions tests/providers/amazon/aws/operators/test_neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
NeptuneStartDbClusterOperator,
NeptuneStopDbClusterOperator,
)
from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields

CLUSTER_ID = "test_cluster"

Expand Down Expand Up @@ -201,6 +202,16 @@ def test_start_cluster_instances_not_ready_defer(self, mock_conn, mock_defer):
# mock_defer.assert_has_calls(calls)
assert mock_defer.call_count == 2

def test_template_fields(self):
operator = NeptuneStartDbClusterOperator(
task_id="task_test",
db_cluster_id=CLUSTER_ID,
deferrable=True,
wait_for_completion=False,
aws_conn_id="aws_default",
)
validate_template_fields(operator)


class TestNeptuneStopClusterOperator:
@mock.patch.object(NeptuneHook, "conn")
Expand Down Expand Up @@ -368,3 +379,13 @@ def test_stop_cluster_deferrable(self, mock_conn):

with pytest.raises(TaskDeferred):
operator.execute(None)

def test_template_fields(self):
operator = NeptuneStopDbClusterOperator(
task_id="task_test",
db_cluster_id=CLUSTER_ID,
deferrable=True,
wait_for_completion=False,
aws_conn_id="aws_default",
)
validate_template_fields(operator)
5 changes: 5 additions & 0 deletions tests/providers/amazon/aws/operators/test_quicksight.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from airflow.providers.amazon.aws.hooks.quicksight import QuickSightHook
from airflow.providers.amazon.aws.operators.quicksight import QuickSightCreateIngestionOperator
from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields

DATA_SET_ID = "DemoDataSet"
INGESTION_ID = "DemoDataSet_Ingestion"
Expand Down Expand Up @@ -80,3 +81,7 @@ def test_execute(self, mock_create_ingestion):
wait_for_completion=True,
check_interval=30,
)

def test_template_fields(self):
operator = QuickSightCreateIngestionOperator(**self.default_op_kwargs)
validate_template_fields(operator)
Loading