diff --git a/TESTING.rst b/TESTING.rst
index 30e6fff897a2e..1efc63c3ee34a 100644
--- a/TESTING.rst
+++ b/TESTING.rst
@@ -232,7 +232,7 @@ Example test here:
res = render_chart('GIT-SYNC', helm_settings,
show_only=["templates/scheduler/scheduler-deployment.yaml"])
dep: k8s.V1Deployment = render_k8s_object(res[0], k8s.V1Deployment)
- self.assertEqual("dags", dep.spec.template.spec.volumes[1].name)
+ assert "dags" == dep.spec.template.spec.volumes[1].name
To run tests using breeze run the following command
@@ -330,7 +330,7 @@ Example of the ``redis`` integration test:
hook = RedisHook(redis_conn_id='redis_default')
redis = hook.get_conn()
- self.assertTrue(redis.ping(), 'Connection to Redis with PING works.')
+ assert redis.ping(), 'Connection to Redis with PING works.'
The markers can be specified at the test level or the class level (then all tests in this class
require an integration). You can add multiple markers with different integrations for tests that
diff --git a/chart/tests/test_basic_helm_chart.py b/chart/tests/test_basic_helm_chart.py
index 7d835a6fa9cb5..f8526b040c2e0 100644
--- a/chart/tests/test_basic_helm_chart.py
+++ b/chart/tests/test_basic_helm_chart.py
@@ -39,52 +39,47 @@ def test_basic_deployments(self):
list_of_kind_names_tuples = [
(k8s_object['kind'], k8s_object['metadata']['name']) for k8s_object in k8s_objects
]
- self.assertEqual(
- list_of_kind_names_tuples,
- [
- ('ServiceAccount', 'TEST-BASIC-scheduler'),
- ('ServiceAccount', 'TEST-BASIC-webserver'),
- ('ServiceAccount', 'TEST-BASIC-worker'),
- ('Secret', 'TEST-BASIC-postgresql'),
- ('Secret', 'TEST-BASIC-airflow-metadata'),
- ('Secret', 'TEST-BASIC-airflow-result-backend'),
- ('ConfigMap', 'TEST-BASIC-airflow-config'),
- ('Role', 'TEST-BASIC-pod-launcher-role'),
- ('Role', 'TEST-BASIC-pod-log-reader-role'),
- ('RoleBinding', 'TEST-BASIC-pod-launcher-rolebinding'),
- ('RoleBinding', 'TEST-BASIC-pod-log-reader-rolebinding'),
- ('Service', 'TEST-BASIC-postgresql-headless'),
- ('Service', 'TEST-BASIC-postgresql'),
- ('Service', 'TEST-BASIC-statsd'),
- ('Service', 'TEST-BASIC-webserver'),
- ('Deployment', 'TEST-BASIC-scheduler'),
- ('Deployment', 'TEST-BASIC-statsd'),
- ('Deployment', 'TEST-BASIC-webserver'),
- ('StatefulSet', 'TEST-BASIC-postgresql'),
- ('Secret', 'TEST-BASIC-fernet-key'),
- ('Job', 'TEST-BASIC-create-user'),
- ('Job', 'TEST-BASIC-run-airflow-migrations'),
- ],
- )
- self.assertEqual(OBJECT_COUNT_IN_BASIC_DEPLOYMENT, len(k8s_objects))
+ assert list_of_kind_names_tuples == [
+ ('ServiceAccount', 'TEST-BASIC-scheduler'),
+ ('ServiceAccount', 'TEST-BASIC-webserver'),
+ ('ServiceAccount', 'TEST-BASIC-worker'),
+ ('Secret', 'TEST-BASIC-postgresql'),
+ ('Secret', 'TEST-BASIC-airflow-metadata'),
+ ('Secret', 'TEST-BASIC-airflow-result-backend'),
+ ('ConfigMap', 'TEST-BASIC-airflow-config'),
+ ('Role', 'TEST-BASIC-pod-launcher-role'),
+ ('Role', 'TEST-BASIC-pod-log-reader-role'),
+ ('RoleBinding', 'TEST-BASIC-pod-launcher-rolebinding'),
+ ('RoleBinding', 'TEST-BASIC-pod-log-reader-rolebinding'),
+ ('Service', 'TEST-BASIC-postgresql-headless'),
+ ('Service', 'TEST-BASIC-postgresql'),
+ ('Service', 'TEST-BASIC-statsd'),
+ ('Service', 'TEST-BASIC-webserver'),
+ ('Deployment', 'TEST-BASIC-scheduler'),
+ ('Deployment', 'TEST-BASIC-statsd'),
+ ('Deployment', 'TEST-BASIC-webserver'),
+ ('StatefulSet', 'TEST-BASIC-postgresql'),
+ ('Secret', 'TEST-BASIC-fernet-key'),
+ ('Job', 'TEST-BASIC-create-user'),
+ ('Job', 'TEST-BASIC-run-airflow-migrations'),
+ ]
+ assert OBJECT_COUNT_IN_BASIC_DEPLOYMENT == len(k8s_objects)
for k8s_object in k8s_objects:
labels = jmespath.search('metadata.labels', k8s_object) or {}
if 'postgresql' in labels.get('chart'):
continue
k8s_name = k8s_object['kind'] + ":" + k8s_object['metadata']['name']
- self.assertEqual(
- 'TEST-VALUE',
- labels.get("TEST-LABEL"),
- f"Missing label TEST-LABEL on {k8s_name}. Current labels: {labels}",
- )
+ assert 'TEST-VALUE' == labels.get(
+ "TEST-LABEL"
+ ), f"Missing label TEST-LABEL on {k8s_name}. Current labels: {labels}"
def test_basic_deployment_without_default_users(self):
k8s_objects = render_chart("TEST-BASIC", {"webserver": {'defaultUser': {'enabled': False}}})
list_of_kind_names_tuples = [
(k8s_object['kind'], k8s_object['metadata']['name']) for k8s_object in k8s_objects
]
- self.assertNotIn(('Job', 'TEST-BASIC-create-user'), list_of_kind_names_tuples)
- self.assertEqual(OBJECT_COUNT_IN_BASIC_DEPLOYMENT - 1, len(k8s_objects))
+ assert ('Job', 'TEST-BASIC-create-user') not in list_of_kind_names_tuples
+ assert OBJECT_COUNT_IN_BASIC_DEPLOYMENT - 1 == len(k8s_objects)
def test_network_policies_are_valid(self):
k8s_objects = render_chart(
@@ -109,7 +104,7 @@ def test_network_policies_are_valid(self):
('NetworkPolicy', 'TEST-BASIC-worker-policy'),
]
for kind_name in expected_kind_names:
- self.assertIn(kind_name, kind_names_tuples)
+ assert kind_name in kind_names_tuples
def test_chart_is_consistent_with_official_airflow_image(self):
def get_k8s_objs_with_image(obj: Union[List[Any], Dict[str, Any]]) -> List[Dict[str, Any]]:
@@ -137,4 +132,4 @@ def get_k8s_objs_with_image(obj: Union[List[Any], Dict[str, Any]]) -> List[Dict[
image: str = obj["image"] # pylint: disable=invalid-sequence-index
if image.startswith(image_repo):
# Make sure that a command is not specified
- self.assertNotIn("command", obj)
+ assert "command" not in obj
diff --git a/chart/tests/test_celery_kubernetes_executor.py b/chart/tests/test_celery_kubernetes_executor.py
index 6c54e804ea514..6e22ad2446d5d 100644
--- a/chart/tests/test_celery_kubernetes_executor.py
+++ b/chart/tests/test_celery_kubernetes_executor.py
@@ -32,8 +32,8 @@ def test_should_create_a_worker_deployment_with_the_celery_executor(self):
show_only=["templates/workers/worker-deployment.yaml"],
)
- self.assertEqual("config", jmespath.search("spec.template.spec.volumes[0].name", docs[0]))
- self.assertEqual("dags", jmespath.search("spec.template.spec.volumes[1].name", docs[0]))
+ assert "config" == jmespath.search("spec.template.spec.volumes[0].name", docs[0])
+ assert "dags" == jmespath.search("spec.template.spec.volumes[1].name", docs[0])
def test_should_create_a_worker_deployment_with_the_celery_kubernetes_executor(self):
docs = render_chart(
@@ -44,5 +44,5 @@ def test_should_create_a_worker_deployment_with_the_celery_kubernetes_executor(s
show_only=["templates/workers/worker-deployment.yaml"],
)
- self.assertEqual("config", jmespath.search("spec.template.spec.volumes[0].name", docs[0]))
- self.assertEqual("dags", jmespath.search("spec.template.spec.volumes[1].name", docs[0]))
+ assert "config" == jmespath.search("spec.template.spec.volumes[0].name", docs[0])
+ assert "dags" == jmespath.search("spec.template.spec.volumes[1].name", docs[0])
diff --git a/chart/tests/test_cleanup_pods.py b/chart/tests/test_cleanup_pods.py
index df2b3a0c4de0b..68f5b7c62c3f5 100644
--- a/chart/tests/test_cleanup_pods.py
+++ b/chart/tests/test_cleanup_pods.py
@@ -31,27 +31,21 @@ def test_should_create_cronjob_for_enabled_cleanup(self):
show_only=["templates/cleanup/cleanup-cronjob.yaml"],
)
- self.assertEqual(
- "airflow-cleanup-pods",
- jmespath.search("spec.jobTemplate.spec.template.spec.containers[0].name", docs[0]),
+ assert "airflow-cleanup-pods" == jmespath.search(
+ "spec.jobTemplate.spec.template.spec.containers[0].name", docs[0]
)
- self.assertEqual(
- "apache/airflow:2.0.0",
- jmespath.search("spec.jobTemplate.spec.template.spec.containers[0].image", docs[0]),
+ assert "apache/airflow:2.0.0" == jmespath.search(
+ "spec.jobTemplate.spec.template.spec.containers[0].image", docs[0]
)
- self.assertIn(
- {"name": "config", "configMap": {"name": "RELEASE-NAME-airflow-config"}},
- jmespath.search("spec.jobTemplate.spec.template.spec.volumes", docs[0]),
- )
- self.assertIn(
- {
- "name": "config",
- "mountPath": "/opt/airflow/airflow.cfg",
- "subPath": "airflow.cfg",
- "readOnly": True,
- },
- jmespath.search("spec.jobTemplate.spec.template.spec.containers[0].volumeMounts", docs[0]),
+ assert {"name": "config", "configMap": {"name": "RELEASE-NAME-airflow-config"}} in jmespath.search(
+ "spec.jobTemplate.spec.template.spec.volumes", docs[0]
)
+ assert {
+ "name": "config",
+ "mountPath": "/opt/airflow/airflow.cfg",
+ "subPath": "airflow.cfg",
+ "readOnly": True,
+ } in jmespath.search("spec.jobTemplate.spec.template.spec.containers[0].volumeMounts", docs[0])
def test_should_change_image_when_set_airflow_image(self):
docs = render_chart(
@@ -62,7 +56,6 @@ def test_should_change_image_when_set_airflow_image(self):
show_only=["templates/cleanup/cleanup-cronjob.yaml"],
)
- self.assertEqual(
- "airflow:test",
- jmespath.search("spec.jobTemplate.spec.template.spec.containers[0].image", docs[0]),
+ assert "airflow:test" == jmespath.search(
+ "spec.jobTemplate.spec.template.spec.containers[0].image", docs[0]
)
diff --git a/chart/tests/test_dags_persistent_volume_claim.py b/chart/tests/test_dags_persistent_volume_claim.py
index 946c40fe7c319..b0035c77d5632 100644
--- a/chart/tests/test_dags_persistent_volume_claim.py
+++ b/chart/tests/test_dags_persistent_volume_claim.py
@@ -29,7 +29,7 @@ def test_should_not_generate_a_document_if_persistence_is_disabled(self):
show_only=["templates/dags-persistent-volume-claim.yaml"],
)
- self.assertEqual(0, len(docs))
+ assert 0 == len(docs)
def test_should_not_generate_a_document_when_using_an_existing_claim(self):
docs = render_chart(
@@ -37,7 +37,7 @@ def test_should_not_generate_a_document_when_using_an_existing_claim(self):
show_only=["templates/dags-persistent-volume-claim.yaml"],
)
- self.assertEqual(0, len(docs))
+ assert 0 == len(docs)
def test_should_generate_a_document_if_persistence_is_enabled_and_not_using_an_existing_claim(self):
docs = render_chart(
@@ -45,7 +45,7 @@ def test_should_generate_a_document_if_persistence_is_enabled_and_not_using_an_e
show_only=["templates/dags-persistent-volume-claim.yaml"],
)
- self.assertEqual(1, len(docs))
+ assert 1 == len(docs)
def test_should_set_pvc_details_correctly(self):
docs = render_chart(
@@ -63,11 +63,8 @@ def test_should_set_pvc_details_correctly(self):
show_only=["templates/dags-persistent-volume-claim.yaml"],
)
- self.assertEqual(
- {
- "accessModes": ["ReadWriteMany"],
- "resources": {"requests": {"storage": "1G"}},
- "storageClassName": "MyStorageClass",
- },
- jmespath.search("spec", docs[0]),
- )
+ assert {
+ "accessModes": ["ReadWriteMany"],
+ "resources": {"requests": {"storage": "1G"}},
+ "storageClassName": "MyStorageClass",
+ } == jmespath.search("spec", docs[0])
diff --git a/chart/tests/test_extra_configmaps_secrets.py b/chart/tests/test_extra_configmaps_secrets.py
index 378d80ed4d261..88fb77aba232f 100644
--- a/chart/tests/test_extra_configmaps_secrets.py
+++ b/chart/tests/test_extra_configmaps_secrets.py
@@ -50,7 +50,7 @@ def test_extra_configmaps(self):
("ConfigMap", f"{RELEASE_NAME}-airflow-variables"),
("ConfigMap", f"{RELEASE_NAME}-other-variables"),
]
- self.assertEqual(set(k8s_objects_by_key.keys()), set(all_expected_keys))
+ assert set(k8s_objects_by_key.keys()) == set(all_expected_keys)
all_expected_data = [
{"AIRFLOW_VAR_HELLO_MESSAGE": "Hi!", "AIRFLOW_VAR_KUBERNETES_NAMESPACE": "default"},
@@ -58,7 +58,7 @@ def test_extra_configmaps(self):
]
for expected_key, expected_data in zip(all_expected_keys, all_expected_data):
configmap_obj = k8s_objects_by_key[expected_key]
- self.assertEqual(configmap_obj["data"], expected_data)
+ assert configmap_obj["data"] == expected_data
def test_extra_secrets(self):
values_str = textwrap.dedent(
@@ -88,7 +88,7 @@ def test_extra_secrets(self):
("Secret", f"{RELEASE_NAME}-airflow-connections"),
("Secret", f"{RELEASE_NAME}-other-secrets"),
]
- self.assertEqual(set(k8s_objects_by_key.keys()), set(all_expected_keys))
+ assert set(k8s_objects_by_key.keys()) == set(all_expected_keys)
all_expected_data = [
{"AIRFLOW_CON_AWS": b64encode(b"aws_connection_string").decode("utf-8")},
@@ -106,5 +106,5 @@ def test_extra_secrets(self):
all_expected_keys, all_expected_data, all_expected_string_data
):
configmap_obj = k8s_objects_by_key[expected_key]
- self.assertEqual(configmap_obj["data"], expected_data)
- self.assertEqual(configmap_obj["stringData"], expected_string_data)
+ assert configmap_obj["data"] == expected_data
+ assert configmap_obj["stringData"] == expected_string_data
diff --git a/chart/tests/test_extra_env_env_from.py b/chart/tests/test_extra_env_env_from.py
index 7e1b28df58efe..a2ac58f165666 100644
--- a/chart/tests/test_extra_env_env_from.py
+++ b/chart/tests/test_extra_env_env_from.py
@@ -98,7 +98,7 @@ def test_extra_env(self, k8s_obj_key, env_paths):
k8s_object = self.k8s_objects_by_key[k8s_obj_key]
for path in env_paths:
env = jmespath.search(f"{path}.env", k8s_object)
- self.assertIn(expected_env_as_str, yaml.dump(env))
+ assert expected_env_as_str in yaml.dump(env)
@parameterized.expand(PARAMS)
def test_extra_env_from(self, k8s_obj_key, env_from_paths):
@@ -114,4 +114,4 @@ def test_extra_env_from(self, k8s_obj_key, env_from_paths):
k8s_object = self.k8s_objects_by_key[k8s_obj_key]
for path in env_from_paths:
env_from = jmespath.search(f"{path}.envFrom", k8s_object)
- self.assertIn(expected_env_from_as_str, yaml.dump(env_from))
+ assert expected_env_from_as_str in yaml.dump(env_from)
diff --git a/chart/tests/test_flower_authorization.py b/chart/tests/test_flower_authorization.py
index 0520ddd28f273..4ef4db96f3427 100644
--- a/chart/tests/test_flower_authorization.py
+++ b/chart/tests/test_flower_authorization.py
@@ -33,17 +33,14 @@ def test_should_create_flower_deployment_with_authorization(self):
show_only=["templates/flower/flower-deployment.yaml"],
)
- self.assertEqual(
- "AIRFLOW__CELERY__FLOWER_BASIC_AUTH",
- jmespath.search("spec.template.spec.containers[0].env[0].name", docs[0]),
+ assert "AIRFLOW__CELERY__FLOWER_BASIC_AUTH" == jmespath.search(
+ "spec.template.spec.containers[0].env[0].name", docs[0]
)
- self.assertEqual(
- ['curl', '--user', '$AIRFLOW__CELERY__FLOWER_BASIC_AUTH', 'localhost:7777'],
- jmespath.search("spec.template.spec.containers[0].livenessProbe.exec.command", docs[0]),
+ assert ['curl', '--user', '$AIRFLOW__CELERY__FLOWER_BASIC_AUTH', 'localhost:7777'] == jmespath.search(
+ "spec.template.spec.containers[0].livenessProbe.exec.command", docs[0]
)
- self.assertEqual(
- ['curl', '--user', '$AIRFLOW__CELERY__FLOWER_BASIC_AUTH', 'localhost:7777'],
- jmespath.search("spec.template.spec.containers[0].readinessProbe.exec.command", docs[0]),
+ assert ['curl', '--user', '$AIRFLOW__CELERY__FLOWER_BASIC_AUTH', 'localhost:7777'] == jmespath.search(
+ "spec.template.spec.containers[0].readinessProbe.exec.command", docs[0]
)
def test_should_create_flower_deployment_without_authorization(self):
@@ -55,15 +52,12 @@ def test_should_create_flower_deployment_without_authorization(self):
show_only=["templates/flower/flower-deployment.yaml"],
)
- self.assertEqual(
- "AIRFLOW__CORE__FERNET_KEY",
- jmespath.search("spec.template.spec.containers[0].env[0].name", docs[0]),
+ assert "AIRFLOW__CORE__FERNET_KEY" == jmespath.search(
+ "spec.template.spec.containers[0].env[0].name", docs[0]
)
- self.assertEqual(
- ['curl', 'localhost:7777'],
- jmespath.search("spec.template.spec.containers[0].livenessProbe.exec.command", docs[0]),
+ assert ['curl', 'localhost:7777'] == jmespath.search(
+ "spec.template.spec.containers[0].livenessProbe.exec.command", docs[0]
)
- self.assertEqual(
- ['curl', 'localhost:7777'],
- jmespath.search("spec.template.spec.containers[0].readinessProbe.exec.command", docs[0]),
+ assert ['curl', 'localhost:7777'] == jmespath.search(
+ "spec.template.spec.containers[0].readinessProbe.exec.command", docs[0]
)
diff --git a/chart/tests/test_git_sync_scheduler.py b/chart/tests/test_git_sync_scheduler.py
index 58ea1c73e6523..1bfdf271fa9d8 100644
--- a/chart/tests/test_git_sync_scheduler.py
+++ b/chart/tests/test_git_sync_scheduler.py
@@ -29,7 +29,7 @@ def test_should_add_dags_volume(self):
show_only=["templates/scheduler/scheduler-deployment.yaml"],
)
- self.assertEqual("dags", jmespath.search("spec.template.spec.volumes[1].name", docs[0]))
+ assert "dags" == jmespath.search("spec.template.spec.volumes[1].name", docs[0])
def test_validate_the_git_sync_container_spec(self):
docs = render_chart(
@@ -64,27 +64,24 @@ def test_validate_the_git_sync_container_spec(self):
show_only=["templates/scheduler/scheduler-deployment.yaml"],
)
- self.assertEqual(
- {
- "name": "git-sync-test",
- "securityContext": {"runAsUser": 65533},
- "image": "test-registry/test-repo:test-tag",
- "imagePullPolicy": "Allways",
- "env": [
- {"name": "GIT_SYNC_REV", "value": "HEAD"},
- {"name": "GIT_SYNC_BRANCH", "value": "test-branch"},
- {"name": "GIT_SYNC_REPO", "value": "https://github.com/apache/airflow.git"},
- {"name": "GIT_SYNC_DEPTH", "value": "1"},
- {"name": "GIT_SYNC_ROOT", "value": "/git-root"},
- {"name": "GIT_SYNC_DEST", "value": "test-dest"},
- {"name": "GIT_SYNC_ADD_USER", "value": "true"},
- {"name": "GIT_SYNC_WAIT", "value": "66"},
- {"name": "GIT_SYNC_MAX_SYNC_FAILURES", "value": "70"},
- ],
- "volumeMounts": [{"mountPath": "/git-root", "name": "dags"}],
- },
- jmespath.search("spec.template.spec.containers[1]", docs[0]),
- )
+ assert {
+ "name": "git-sync-test",
+ "securityContext": {"runAsUser": 65533},
+ "image": "test-registry/test-repo:test-tag",
+ "imagePullPolicy": "Allways",
+ "env": [
+ {"name": "GIT_SYNC_REV", "value": "HEAD"},
+ {"name": "GIT_SYNC_BRANCH", "value": "test-branch"},
+ {"name": "GIT_SYNC_REPO", "value": "https://github.com/apache/airflow.git"},
+ {"name": "GIT_SYNC_DEPTH", "value": "1"},
+ {"name": "GIT_SYNC_ROOT", "value": "/git-root"},
+ {"name": "GIT_SYNC_DEST", "value": "test-dest"},
+ {"name": "GIT_SYNC_ADD_USER", "value": "true"},
+ {"name": "GIT_SYNC_WAIT", "value": "66"},
+ {"name": "GIT_SYNC_MAX_SYNC_FAILURES", "value": "70"},
+ ],
+ "volumeMounts": [{"mountPath": "/git-root", "name": "dags"}],
+ } == jmespath.search("spec.template.spec.containers[1]", docs[0])
def test_validate_if_ssh_params_are_added(self):
docs = render_chart(
@@ -102,22 +99,19 @@ def test_validate_if_ssh_params_are_added(self):
show_only=["templates/scheduler/scheduler-deployment.yaml"],
)
- self.assertIn(
- {"name": "GIT_SSH_KEY_FILE", "value": "/etc/git-secret/ssh"},
- jmespath.search("spec.template.spec.containers[1].env", docs[0]),
+ assert {"name": "GIT_SSH_KEY_FILE", "value": "/etc/git-secret/ssh"} in jmespath.search(
+ "spec.template.spec.containers[1].env", docs[0]
)
- self.assertIn(
- {"name": "GIT_SYNC_SSH", "value": "true"},
- jmespath.search("spec.template.spec.containers[1].env", docs[0]),
+ assert {"name": "GIT_SYNC_SSH", "value": "true"} in jmespath.search(
+ "spec.template.spec.containers[1].env", docs[0]
)
- self.assertIn(
- {"name": "GIT_KNOWN_HOSTS", "value": "false"},
- jmespath.search("spec.template.spec.containers[1].env", docs[0]),
- )
- self.assertIn(
- {"name": "git-sync-ssh-key", "secret": {"secretName": "ssh-secret", "defaultMode": 288}},
- jmespath.search("spec.template.spec.volumes", docs[0]),
+ assert {"name": "GIT_KNOWN_HOSTS", "value": "false"} in jmespath.search(
+ "spec.template.spec.containers[1].env", docs[0]
)
+ assert {
+ "name": "git-sync-ssh-key",
+ "secret": {"secretName": "ssh-secret", "defaultMode": 288},
+ } in jmespath.search("spec.template.spec.volumes", docs[0])
def test_should_set_username_and_pass_env_variables(self):
docs = render_chart(
@@ -133,20 +127,14 @@ def test_should_set_username_and_pass_env_variables(self):
show_only=["templates/scheduler/scheduler-deployment.yaml"],
)
- self.assertIn(
- {
- "name": "GIT_SYNC_USERNAME",
- "valueFrom": {"secretKeyRef": {"name": "user-pass-secret", "key": "GIT_SYNC_USERNAME"}},
- },
- jmespath.search("spec.template.spec.containers[1].env", docs[0]),
- )
- self.assertIn(
- {
- "name": "GIT_SYNC_PASSWORD",
- "valueFrom": {"secretKeyRef": {"name": "user-pass-secret", "key": "GIT_SYNC_PASSWORD"}},
- },
- jmespath.search("spec.template.spec.containers[1].env", docs[0]),
- )
+ assert {
+ "name": "GIT_SYNC_USERNAME",
+ "valueFrom": {"secretKeyRef": {"name": "user-pass-secret", "key": "GIT_SYNC_USERNAME"}},
+ } in jmespath.search("spec.template.spec.containers[1].env", docs[0])
+ assert {
+ "name": "GIT_SYNC_PASSWORD",
+ "valueFrom": {"secretKeyRef": {"name": "user-pass-secret", "key": "GIT_SYNC_PASSWORD"}},
+ } in jmespath.search("spec.template.spec.containers[1].env", docs[0])
def test_should_set_the_volume_claim_correctly_when_using_an_existing_claim(self):
docs = render_chart(
@@ -154,9 +142,8 @@ def test_should_set_the_volume_claim_correctly_when_using_an_existing_claim(self
show_only=["templates/scheduler/scheduler-deployment.yaml"],
)
- self.assertIn(
- {"name": "dags", "persistentVolumeClaim": {"claimName": "test-claim"}},
- jmespath.search("spec.template.spec.volumes", docs[0]),
+ assert {"name": "dags", "persistentVolumeClaim": {"claimName": "test-claim"}} in jmespath.search(
+ "spec.template.spec.volumes", docs[0]
)
def test_should_add_extra_volume_and_extra_volume_mount(self):
@@ -176,10 +163,9 @@ def test_should_add_extra_volume_and_extra_volume_mount(self):
show_only=["templates/scheduler/scheduler-deployment.yaml"],
)
- self.assertIn(
- {"name": "test-volume", "emptyDir": {}}, jmespath.search("spec.template.spec.volumes", docs[0])
+ assert {"name": "test-volume", "emptyDir": {}} in jmespath.search(
+ "spec.template.spec.volumes", docs[0]
)
- self.assertIn(
- {"name": "test-volume", "mountPath": "/opt/test"},
- jmespath.search("spec.template.spec.containers[0].volumeMounts", docs[0]),
+ assert {"name": "test-volume", "mountPath": "/opt/test"} in jmespath.search(
+ "spec.template.spec.containers[0].volumeMounts", docs[0]
)
diff --git a/chart/tests/test_git_sync_webserver.py b/chart/tests/test_git_sync_webserver.py
index 09c9aa3acf300..a232287bda554 100644
--- a/chart/tests/test_git_sync_webserver.py
+++ b/chart/tests/test_git_sync_webserver.py
@@ -29,7 +29,7 @@ def test_should_add_dags_volume_to_the_webserver_if_git_sync_and_persistence_is_
show_only=["templates/webserver/webserver-deployment.yaml"],
)
- self.assertEqual("dags", jmespath.search("spec.template.spec.volumes[1].name", docs[0]))
+ assert "dags" == jmespath.search("spec.template.spec.volumes[1].name", docs[0])
def test_should_add_dags_volume_to_the_webserver_if_git_sync_is_enabled_and_persistence_is_disabled(self):
docs = render_chart(
@@ -37,7 +37,7 @@ def test_should_add_dags_volume_to_the_webserver_if_git_sync_is_enabled_and_pers
show_only=["templates/webserver/webserver-deployment.yaml"],
)
- self.assertEqual("dags", jmespath.search("spec.template.spec.volumes[1].name", docs[0]))
+ assert "dags" == jmespath.search("spec.template.spec.volumes[1].name", docs[0])
def test_should_add_git_sync_container_to_webserver_if_persistence_is_not_enabled_but_git_sync_is(self):
docs = render_chart(
@@ -50,7 +50,7 @@ def test_should_add_git_sync_container_to_webserver_if_persistence_is_not_enable
show_only=["templates/webserver/webserver-deployment.yaml"],
)
- self.assertEqual("git-sync", jmespath.search("spec.template.spec.containers[0].name", docs[0]))
+ assert "git-sync" == jmespath.search("spec.template.spec.containers[0].name", docs[0])
def test_should_have_service_account_defined(self):
docs = render_chart(
@@ -58,6 +58,4 @@ def test_should_have_service_account_defined(self):
show_only=["templates/webserver/webserver-deployment.yaml"],
)
- self.assertEqual(
- "RELEASE-NAME-webserver", jmespath.search("spec.template.spec.serviceAccountName", docs[0])
- )
+ assert "RELEASE-NAME-webserver" == jmespath.search("spec.template.spec.serviceAccountName", docs[0])
diff --git a/chart/tests/test_git_sync_worker.py b/chart/tests/test_git_sync_worker.py
index a56b0dc808c6d..c48d00151872b 100644
--- a/chart/tests/test_git_sync_worker.py
+++ b/chart/tests/test_git_sync_worker.py
@@ -32,8 +32,8 @@ def test_should_add_dags_volume_to_the_worker_if_git_sync_and_persistence_is_ena
show_only=["templates/workers/worker-deployment.yaml"],
)
- self.assertEqual("config", jmespath.search("spec.template.spec.volumes[0].name", docs[0]))
- self.assertEqual("dags", jmespath.search("spec.template.spec.volumes[1].name", docs[0]))
+ assert "config" == jmespath.search("spec.template.spec.volumes[0].name", docs[0])
+ assert "dags" == jmespath.search("spec.template.spec.volumes[1].name", docs[0])
def test_should_add_dags_volume_to_the_worker_if_git_sync_is_enabled_and_persistence_is_disabled(self):
docs = render_chart(
@@ -44,8 +44,8 @@ def test_should_add_dags_volume_to_the_worker_if_git_sync_is_enabled_and_persist
show_only=["templates/workers/worker-deployment.yaml"],
)
- self.assertEqual("config", jmespath.search("spec.template.spec.volumes[0].name", docs[0]))
- self.assertEqual("dags", jmespath.search("spec.template.spec.volumes[1].name", docs[0]))
+ assert "config" == jmespath.search("spec.template.spec.volumes[0].name", docs[0])
+ assert "dags" == jmespath.search("spec.template.spec.volumes[1].name", docs[0])
def test_should_add_git_sync_container_to_worker_if_persistence_is_not_enabled_but_git_sync_is(self):
docs = render_chart(
@@ -59,7 +59,7 @@ def test_should_add_git_sync_container_to_worker_if_persistence_is_not_enabled_b
show_only=["templates/workers/worker-deployment.yaml"],
)
- self.assertEqual("git-sync", jmespath.search("spec.template.spec.containers[0].name", docs[0]))
+ assert "git-sync" == jmespath.search("spec.template.spec.containers[0].name", docs[0])
def test_should_not_add_sync_container_to_worker_if_git_sync_and_persistence_are_enabled(self):
docs = render_chart(
@@ -73,4 +73,4 @@ def test_should_not_add_sync_container_to_worker_if_git_sync_and_persistence_are
show_only=["templates/workers/worker-deployment.yaml"],
)
- self.assertNotEqual("git-sync", jmespath.search("spec.template.spec.containers[0].name", docs[0]))
+ assert "git-sync" != jmespath.search("spec.template.spec.containers[0].name", docs[0])
diff --git a/chart/tests/test_ingress_web.py b/chart/tests/test_ingress_web.py
index 7fcedaf4fc89c..6d64b8b4b09ab 100644
--- a/chart/tests/test_ingress_web.py
+++ b/chart/tests/test_ingress_web.py
@@ -34,4 +34,4 @@ def test_should_allow_more_than_one_annotation(self):
values={"ingress": {"enabled": True, "web": {"annotations": {"aa": "bb", "cc": "dd"}}}},
show_only=["templates/webserver/webserver-ingress.yaml"],
)
- self.assertEqual({"aa": "bb", "cc": "dd"}, jmespath.search("metadata.annotations", docs[0]))
+ assert {"aa": "bb", "cc": "dd"} == jmespath.search("metadata.annotations", docs[0])
diff --git a/chart/tests/test_keda.py b/chart/tests/test_keda.py
index 57da31a734a83..132439da40cc9 100644
--- a/chart/tests/test_keda.py
+++ b/chart/tests/test_keda.py
@@ -30,7 +30,7 @@ def test_keda_disabled_by_default(self):
show_only=["templates/workers/worker-kedaautoscaler.yaml"],
validate_schema=False,
)
- self.assertListEqual(docs, [])
+ assert docs == []
@parameterized.expand(
[
@@ -52,6 +52,6 @@ def test_keda_enabled(self, executor, is_created):
validate_schema=False,
)
if is_created:
- self.assertEqual("RELEASE-NAME-worker", jmespath.search("metadata.name", docs[0]))
+ assert "RELEASE-NAME-worker" == jmespath.search("metadata.name", docs[0])
else:
- self.assertListEqual(docs, [])
+ assert docs == []
diff --git a/chart/tests/test_kerberos.py b/chart/tests/test_kerberos.py
index b0cf88d7f70db..4676f65a393a9 100644
--- a/chart/tests/test_kerberos.py
+++ b/chart/tests/test_kerberos.py
@@ -29,4 +29,4 @@ def test_kerberos_not_mentioned_in_render_if_disabled(self):
obj for obj in k8s_objects if obj["metadata"]["name"] != "NO-KERBEROS-airflow-config"
]
k8s_objects_to_consider_str = json.dumps(k8s_objects_to_consider)
- self.assertNotIn("kerberos", k8s_objects_to_consider_str)
+ assert "kerberos" not in k8s_objects_to_consider_str
diff --git a/chart/tests/test_migrate_database_job.py b/chart/tests/test_migrate_database_job.py
index 4b92acaf60100..d1e922513243a 100644
--- a/chart/tests/test_migrate_database_job.py
+++ b/chart/tests/test_migrate_database_job.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
+import re
import unittest
import jmespath
@@ -29,7 +30,5 @@ def test_should_run_by_default(self):
show_only=["templates/migrate-database-job.yaml"],
)
- self.assertRegex(docs[0]["kind"], "Job")
- self.assertEqual(
- "run-airflow-migrations", jmespath.search("spec.template.spec.containers[0].name", docs[0])
- )
+ assert re.search("Job", docs[0]["kind"])
+ assert "run-airflow-migrations" == jmespath.search("spec.template.spec.containers[0].name", docs[0])
diff --git a/chart/tests/test_pod_launcher_role.py b/chart/tests/test_pod_launcher_role.py
index aeca27ef68f8f..99bfe5e199159 100644
--- a/chart/tests/test_pod_launcher_role.py
+++ b/chart/tests/test_pod_launcher_role.py
@@ -45,6 +45,6 @@ def test_pod_launcher_role(self, executor, rbac, allow, expected_accounts):
)
if expected_accounts:
for idx, suffix in enumerate(expected_accounts):
- self.assertEqual(f"RELEASE-NAME-{suffix}", jmespath.search(f"subjects[{idx}].name", docs[0]))
+ assert f"RELEASE-NAME-{suffix}" == jmespath.search(f"subjects[{idx}].name", docs[0])
else:
- self.assertEqual([], docs)
+ assert [] == docs
diff --git a/chart/tests/test_pod_template_file.py b/chart/tests/test_pod_template_file.py
index bd0b7ac32595a..f58fd3bb0f9d9 100644
--- a/chart/tests/test_pod_template_file.py
+++ b/chart/tests/test_pod_template_file.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
+import re
import unittest
from os import remove
from os.path import dirname, realpath
@@ -43,9 +44,9 @@ def test_should_work(self):
show_only=["templates/pod-template-file.yaml"],
)
- self.assertRegex(docs[0]["kind"], "Pod")
- self.assertIsNotNone(jmespath.search("spec.containers[0].image", docs[0]))
- self.assertEqual("base", jmespath.search("spec.containers[0].name", docs[0]))
+ assert re.search("Pod", docs[0]["kind"])
+ assert jmespath.search("spec.containers[0].image", docs[0]) is not None
+ assert "base" == jmespath.search("spec.containers[0].name", docs[0])
def test_should_add_an_init_container_if_git_sync_is_true(self):
docs = render_chart(
@@ -79,29 +80,26 @@ def test_should_add_an_init_container_if_git_sync_is_true(self):
show_only=["templates/pod-template-file.yaml"],
)
- self.assertRegex(docs[0]["kind"], "Pod")
- self.assertEqual(
- {
- "name": "git-sync-test",
- "securityContext": {"runAsUser": 65533},
- "image": "test-registry/test-repo:test-tag",
- "imagePullPolicy": "Allways",
- "env": [
- {"name": "GIT_SYNC_REV", "value": "HEAD"},
- {"name": "GIT_SYNC_BRANCH", "value": "test-branch"},
- {"name": "GIT_SYNC_REPO", "value": "https://github.com/apache/airflow.git"},
- {"name": "GIT_SYNC_DEPTH", "value": "1"},
- {"name": "GIT_SYNC_ROOT", "value": "/git-root"},
- {"name": "GIT_SYNC_DEST", "value": "test-dest"},
- {"name": "GIT_SYNC_ADD_USER", "value": "true"},
- {"name": "GIT_SYNC_WAIT", "value": "66"},
- {"name": "GIT_SYNC_MAX_SYNC_FAILURES", "value": "70"},
- {"name": "GIT_SYNC_ONE_TIME", "value": "true"},
- ],
- "volumeMounts": [{"mountPath": "/git-root", "name": "dags"}],
- },
- jmespath.search("spec.initContainers[0]", docs[0]),
- )
+ assert re.search("Pod", docs[0]["kind"])
+ assert {
+ "name": "git-sync-test",
+ "securityContext": {"runAsUser": 65533},
+ "image": "test-registry/test-repo:test-tag",
+ "imagePullPolicy": "Allways",
+ "env": [
+ {"name": "GIT_SYNC_REV", "value": "HEAD"},
+ {"name": "GIT_SYNC_BRANCH", "value": "test-branch"},
+ {"name": "GIT_SYNC_REPO", "value": "https://github.com/apache/airflow.git"},
+ {"name": "GIT_SYNC_DEPTH", "value": "1"},
+ {"name": "GIT_SYNC_ROOT", "value": "/git-root"},
+ {"name": "GIT_SYNC_DEST", "value": "test-dest"},
+ {"name": "GIT_SYNC_ADD_USER", "value": "true"},
+ {"name": "GIT_SYNC_WAIT", "value": "66"},
+ {"name": "GIT_SYNC_MAX_SYNC_FAILURES", "value": "70"},
+ {"name": "GIT_SYNC_ONE_TIME", "value": "true"},
+ ],
+ "volumeMounts": [{"mountPath": "/git-root", "name": "dags"}],
+ } == jmespath.search("spec.initContainers[0]", docs[0])
def test_validate_if_ssh_params_are_added(self):
docs = render_chart(
@@ -119,21 +117,19 @@ def test_validate_if_ssh_params_are_added(self):
show_only=["templates/pod-template-file.yaml"],
)
- self.assertIn(
- {"name": "GIT_SSH_KEY_FILE", "value": "/etc/git-secret/ssh"},
- jmespath.search("spec.initContainers[0].env", docs[0]),
- )
- self.assertIn(
- {"name": "GIT_SYNC_SSH", "value": "true"}, jmespath.search("spec.initContainers[0].env", docs[0])
+ assert {"name": "GIT_SSH_KEY_FILE", "value": "/etc/git-secret/ssh"} in jmespath.search(
+ "spec.initContainers[0].env", docs[0]
)
- self.assertIn(
- {"name": "GIT_KNOWN_HOSTS", "value": "false"},
- jmespath.search("spec.initContainers[0].env", docs[0]),
+ assert {"name": "GIT_SYNC_SSH", "value": "true"} in jmespath.search(
+ "spec.initContainers[0].env", docs[0]
)
- self.assertIn(
- {"name": "git-sync-ssh-key", "secret": {"secretName": "ssh-secret", "defaultMode": 288}},
- jmespath.search("spec.volumes", docs[0]),
+ assert {"name": "GIT_KNOWN_HOSTS", "value": "false"} in jmespath.search(
+ "spec.initContainers[0].env", docs[0]
)
+ assert {
+ "name": "git-sync-ssh-key",
+ "secret": {"secretName": "ssh-secret", "defaultMode": 288},
+ } in jmespath.search("spec.volumes", docs[0])
def test_validate_if_ssh_known_hosts_are_added(self):
docs = render_chart(
@@ -150,25 +146,18 @@ def test_validate_if_ssh_known_hosts_are_added(self):
},
show_only=["templates/pod-template-file.yaml"],
)
- self.assertIn(
- {"name": "GIT_KNOWN_HOSTS", "value": "true"},
- jmespath.search("spec.initContainers[0].env", docs[0]),
- )
- self.assertIn(
- {
- "name": "git-sync-known-hosts",
- "configMap": {"defaultMode": 288, "name": "RELEASE-NAME-airflow-config"},
- },
- jmespath.search("spec.volumes", docs[0]),
- )
- self.assertIn(
- {
- "name": "git-sync-known-hosts",
- "mountPath": "/etc/git-secret/known_hosts",
- "subPath": "known_hosts",
- },
- jmespath.search("spec.containers[0].volumeMounts", docs[0]),
+ assert {"name": "GIT_KNOWN_HOSTS", "value": "true"} in jmespath.search(
+ "spec.initContainers[0].env", docs[0]
)
+ assert {
+ "name": "git-sync-known-hosts",
+ "configMap": {"defaultMode": 288, "name": "RELEASE-NAME-airflow-config"},
+ } in jmespath.search("spec.volumes", docs[0])
+ assert {
+ "name": "git-sync-known-hosts",
+ "mountPath": "/etc/git-secret/known_hosts",
+ "subPath": "known_hosts",
+ } in jmespath.search("spec.containers[0].volumeMounts", docs[0])
def test_should_set_username_and_pass_env_variables(self):
docs = render_chart(
@@ -184,20 +173,14 @@ def test_should_set_username_and_pass_env_variables(self):
show_only=["templates/pod-template-file.yaml"],
)
- self.assertIn(
- {
- "name": "GIT_SYNC_USERNAME",
- "valueFrom": {"secretKeyRef": {"name": "user-pass-secret", "key": "GIT_SYNC_USERNAME"}},
- },
- jmespath.search("spec.initContainers[0].env", docs[0]),
- )
- self.assertIn(
- {
- "name": "GIT_SYNC_PASSWORD",
- "valueFrom": {"secretKeyRef": {"name": "user-pass-secret", "key": "GIT_SYNC_PASSWORD"}},
- },
- jmespath.search("spec.initContainers[0].env", docs[0]),
- )
+ assert {
+ "name": "GIT_SYNC_USERNAME",
+ "valueFrom": {"secretKeyRef": {"name": "user-pass-secret", "key": "GIT_SYNC_USERNAME"}},
+ } in jmespath.search("spec.initContainers[0].env", docs[0])
+ assert {
+ "name": "GIT_SYNC_PASSWORD",
+ "valueFrom": {"secretKeyRef": {"name": "user-pass-secret", "key": "GIT_SYNC_PASSWORD"}},
+ } in jmespath.search("spec.initContainers[0].env", docs[0])
def test_should_set_the_volume_claim_correctly_when_using_an_existing_claim(self):
docs = render_chart(
@@ -205,9 +188,8 @@ def test_should_set_the_volume_claim_correctly_when_using_an_existing_claim(self
show_only=["templates/pod-template-file.yaml"],
)
- self.assertIn(
- {"name": "dags", "persistentVolumeClaim": {"claimName": "test-claim"}},
- jmespath.search("spec.volumes", docs[0]),
+ assert {"name": "dags", "persistentVolumeClaim": {"claimName": "test-claim"}} in jmespath.search(
+ "spec.volumes", docs[0]
)
def test_should_set_a_custom_image_in_pod_template(self):
@@ -216,9 +198,9 @@ def test_should_set_a_custom_image_in_pod_template(self):
show_only=["templates/pod-template-file.yaml"],
)
- self.assertRegex(docs[0]["kind"], "Pod")
- self.assertEqual("dummy_image:latest", jmespath.search("spec.containers[0].image", docs[0]))
- self.assertEqual("base", jmespath.search("spec.containers[0].name", docs[0]))
+ assert re.search("Pod", docs[0]["kind"])
+ assert "dummy_image:latest" == jmespath.search("spec.containers[0].image", docs[0])
+ assert "base" == jmespath.search("spec.containers[0].name", docs[0])
def test_mount_airflow_cfg(self):
docs = render_chart(
@@ -226,20 +208,16 @@ def test_mount_airflow_cfg(self):
show_only=["templates/pod-template-file.yaml"],
)
- self.assertRegex(docs[0]["kind"], "Pod")
- self.assertDictEqual(
- {'configMap': {'name': 'RELEASE-NAME-airflow-config'}, 'name': 'config'},
- jmespath.search("spec.volumes[1]", docs[0]),
- )
- self.assertDictEqual(
- {
- 'name': 'config',
- 'mountPath': '/opt/airflow/airflow.cfg',
- 'subPath': 'airflow.cfg',
- 'readOnly': True,
- },
- jmespath.search("spec.containers[0].volumeMounts[1]", docs[0]),
+ assert re.search("Pod", docs[0]["kind"])
+ assert {'configMap': {'name': 'RELEASE-NAME-airflow-config'}, 'name': 'config'} == jmespath.search(
+ "spec.volumes[1]", docs[0]
)
+ assert {
+ 'name': 'config',
+ 'mountPath': '/opt/airflow/airflow.cfg',
+ 'subPath': 'airflow.cfg',
+ 'readOnly': True,
+ } == jmespath.search("spec.containers[0].volumeMounts[1]", docs[0])
def test_should_create_valid_affinity_and_node_selector(self):
docs = render_chart(
@@ -265,29 +243,20 @@ def test_should_create_valid_affinity_and_node_selector(self):
show_only=["templates/pod-template-file.yaml"],
)
- self.assertRegex(docs[0]["kind"], "Pod")
- self.assertEqual(
- "foo",
- jmespath.search(
- "spec.affinity.nodeAffinity."
- "requiredDuringSchedulingIgnoredDuringExecution."
- "nodeSelectorTerms[0]."
- "matchExpressions[0]."
- "key",
- docs[0],
- ),
+ assert re.search("Pod", docs[0]["kind"])
+ assert "foo" == jmespath.search(
+ "spec.affinity.nodeAffinity."
+ "requiredDuringSchedulingIgnoredDuringExecution."
+ "nodeSelectorTerms[0]."
+ "matchExpressions[0]."
+ "key",
+ docs[0],
)
- self.assertEqual(
- "ssd",
- jmespath.search(
- "spec.nodeSelector.diskType",
- docs[0],
- ),
+ assert "ssd" == jmespath.search(
+ "spec.nodeSelector.diskType",
+ docs[0],
)
- self.assertEqual(
- "dynamic-pods",
- jmespath.search(
- "spec.tolerations[0].key",
- docs[0],
- ),
+ assert "dynamic-pods" == jmespath.search(
+ "spec.tolerations[0].key",
+ docs[0],
)
diff --git a/chart/tests/test_redis.py b/chart/tests/test_redis.py
index d4c5db791df43..a3e116862b062 100644
--- a/chart/tests/test_redis.py
+++ b/chart/tests/test_redis.py
@@ -14,11 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import re
import unittest
from base64 import b64decode
from subprocess import CalledProcessError
from typing import Optional
+import pytest
from parameterized import parameterized
from tests.helm_template_generator import prepare_k8s_lookup_dict, render_chart
@@ -67,16 +69,16 @@ def assert_password_and_broker_url_secrets(
):
if expected_password_match is not None:
redis_password_in_password_secret = self.get_redis_password_in_password_secret(k8s_obj_by_key)
- self.assertRegex(redis_password_in_password_secret, expected_password_match)
+ assert re.search(expected_password_match, redis_password_in_password_secret)
else:
- self.assertNotIn(REDIS_OBJECTS["SECRET_PASSWORD"], k8s_obj_by_key.keys())
+ assert REDIS_OBJECTS["SECRET_PASSWORD"] not in k8s_obj_by_key.keys()
if expected_broker_url_match is not None:
# assert redis broker url in secret
broker_url_in_broker_url_secret = self.get_broker_url_in_broker_url_secret(k8s_obj_by_key)
- self.assertRegex(broker_url_in_broker_url_secret, expected_broker_url_match)
+ assert re.search(expected_broker_url_match, broker_url_in_broker_url_secret)
else:
- self.assertNotIn(REDIS_OBJECTS["SECRET_BROKER_URL"], k8s_obj_by_key.keys())
+ assert REDIS_OBJECTS["SECRET_BROKER_URL"] not in k8s_obj_by_key.keys()
def assert_broker_url_env(
self, k8s_obj_by_key, expected_broker_url_secret_name=REDIS_OBJECTS["SECRET_BROKER_URL"][1]
@@ -84,11 +86,11 @@ def assert_broker_url_env(
broker_url_secret_in_scheduler = self.get_broker_url_secret_in_deployment(
k8s_obj_by_key, "StatefulSet", "worker"
)
- self.assertEqual(broker_url_secret_in_scheduler, expected_broker_url_secret_name)
+ assert broker_url_secret_in_scheduler == expected_broker_url_secret_name
broker_url_secret_in_worker = self.get_broker_url_secret_in_deployment(
k8s_obj_by_key, "Deployment", "scheduler"
)
- self.assertEqual(broker_url_secret_in_worker, expected_broker_url_secret_name)
+ assert broker_url_secret_in_worker == expected_broker_url_secret_name
@parameterized.expand(CELERY_EXECUTORS_PARAMS)
def test_redis_by_chart_default(self, executor):
@@ -103,7 +105,7 @@ def test_redis_by_chart_default(self, executor):
k8s_obj_by_key = prepare_k8s_lookup_dict(k8s_objects)
created_redis_objects = SET_POSSIBLE_REDIS_OBJECT_KEYS & set(k8s_obj_by_key.keys())
- self.assertEqual(created_redis_objects, SET_POSSIBLE_REDIS_OBJECT_KEYS)
+ assert created_redis_objects == SET_POSSIBLE_REDIS_OBJECT_KEYS
self.assert_password_and_broker_url_secrets(
k8s_obj_by_key,
@@ -126,7 +128,7 @@ def test_redis_by_chart_password(self, executor):
k8s_obj_by_key = prepare_k8s_lookup_dict(k8s_objects)
created_redis_objects = SET_POSSIBLE_REDIS_OBJECT_KEYS & set(k8s_obj_by_key.keys())
- self.assertEqual(created_redis_objects, SET_POSSIBLE_REDIS_OBJECT_KEYS)
+ assert created_redis_objects == SET_POSSIBLE_REDIS_OBJECT_KEYS
self.assert_password_and_broker_url_secrets(
k8s_obj_by_key,
@@ -138,7 +140,7 @@ def test_redis_by_chart_password(self, executor):
@parameterized.expand(CELERY_EXECUTORS_PARAMS)
def test_redis_by_chart_password_secret_name_missing_broker_url_secret_name(self, executor):
- with self.assertRaises(CalledProcessError):
+ with pytest.raises(CalledProcessError):
render_chart(
RELEASE_NAME_REDIS,
{
@@ -168,11 +170,10 @@ def test_redis_by_chart_password_secret_name(self, executor):
k8s_obj_by_key = prepare_k8s_lookup_dict(k8s_objects)
created_redis_objects = SET_POSSIBLE_REDIS_OBJECT_KEYS & set(k8s_obj_by_key.keys())
- self.assertEqual(
- created_redis_objects,
- SET_POSSIBLE_REDIS_OBJECT_KEYS
- - {REDIS_OBJECTS["SECRET_PASSWORD"], REDIS_OBJECTS["SECRET_BROKER_URL"]},
- )
+ assert created_redis_objects == SET_POSSIBLE_REDIS_OBJECT_KEYS - {
+ REDIS_OBJECTS["SECRET_PASSWORD"],
+ REDIS_OBJECTS["SECRET_BROKER_URL"],
+ }
self.assert_password_and_broker_url_secrets(
k8s_obj_by_key, expected_password_match=None, expected_broker_url_match=None
@@ -196,7 +197,7 @@ def test_external_redis_broker_url(self, executor):
k8s_obj_by_key = prepare_k8s_lookup_dict(k8s_objects)
created_redis_objects = SET_POSSIBLE_REDIS_OBJECT_KEYS & set(k8s_obj_by_key.keys())
- self.assertEqual(created_redis_objects, {REDIS_OBJECTS["SECRET_BROKER_URL"]})
+ assert created_redis_objects == {REDIS_OBJECTS["SECRET_BROKER_URL"]}
self.assert_password_and_broker_url_secrets(
k8s_obj_by_key,
@@ -221,7 +222,7 @@ def test_external_redis_broker_url_secret_name(self, executor):
k8s_obj_by_key = prepare_k8s_lookup_dict(k8s_objects)
created_redis_objects = SET_POSSIBLE_REDIS_OBJECT_KEYS & set(k8s_obj_by_key.keys())
- self.assertEqual(created_redis_objects, set())
+ assert created_redis_objects == set()
self.assert_password_and_broker_url_secrets(
k8s_obj_by_key, expected_password_match=None, expected_broker_url_match=None
diff --git a/chart/tests/test_scheduler.py b/chart/tests/test_scheduler.py
index eb5225e35c389..4621ce2d9a100 100644
--- a/chart/tests/test_scheduler.py
+++ b/chart/tests/test_scheduler.py
@@ -35,7 +35,7 @@ def test_should_add_extra_volume_and_extra_volume_mount(self):
show_only=["templates/scheduler/scheduler-deployment.yaml"],
)
- self.assertEqual("test-volume", jmespath.search("spec.template.spec.volumes[1].name", docs[0]))
- self.assertEqual(
- "test-volume", jmespath.search("spec.template.spec.containers[0].volumeMounts[3].name", docs[0])
+ assert "test-volume" == jmespath.search("spec.template.spec.volumes[1].name", docs[0])
+ assert "test-volume" == jmespath.search(
+ "spec.template.spec.containers[0].volumeMounts[3].name", docs[0]
)
diff --git a/chart/tests/test_worker.py b/chart/tests/test_worker.py
index 9b3515ef81562..414658907f3b7 100644
--- a/chart/tests/test_worker.py
+++ b/chart/tests/test_worker.py
@@ -35,7 +35,7 @@ def test_should_add_extra_volume_and_extra_volume_mount(self):
show_only=["templates/workers/worker-deployment.yaml"],
)
- self.assertEqual("test-volume", jmespath.search("spec.template.spec.volumes[0].name", docs[0]))
- self.assertEqual(
- "test-volume", jmespath.search("spec.template.spec.containers[0].volumeMounts[0].name", docs[0])
+ assert "test-volume" == jmespath.search("spec.template.spec.volumes[0].name", docs[0])
+ assert "test-volume" == jmespath.search(
+ "spec.template.spec.containers[0].volumeMounts[0].name", docs[0]
)
diff --git a/docs/apache-airflow/best-practices.rst b/docs/apache-airflow/best-practices.rst
index 228926957f204..2d5b409a0a4ea 100644
--- a/docs/apache-airflow/best-practices.rst
+++ b/docs/apache-airflow/best-practices.rst
@@ -145,9 +145,9 @@ Unit tests ensure that there is no incorrect code in your DAG. You can write uni
def test_dag_loaded(self):
dag = self.dagbag.get_dag(dag_id='hello_world')
- self.assertDictEqual(self.dagbag.import_errors, {})
- self.assertIsNotNone(dag)
- self.assertEqual(len(dag.tasks), 1)
+ assert self.dagbag.import_errors == {}
+ assert dag is not None
+ assert len(dag.tasks) == 1
**Unit test a DAG structure:**
This is an example test want to verify the structure of a code-generated DAG against a dict object
@@ -157,12 +157,11 @@ This is an example test want to verify the structure of a code-generated DAG aga
import unittest
class testClass(unittest.TestCase):
def assertDagDictEqual(self,source,dag):
- self.assertEqual(dag.task_dict.keys(),source.keys())
- for task_id,downstream_list in source.items():
- self.assertTrue(dag.has_task(task_id), msg="Missing task_id: {} in dag".format(task_id))
+ assert dag.task_dict.keys() == source.keys()
+ for task_id, downstream_list in source.items():
+ assert dag.has_task(task_id)
task = dag.get_task(task_id)
- self.assertEqual(task.downstream_task_ids, set(downstream_list),
- msg="unexpected downstream link in {}".format(task_id))
+ assert task.downstream_task_ids == set(downstream_list)
def test_dag(self):
self.assertDagDictEqual({
"DummyInstruction_0": ["DummyInstruction_1"],
@@ -193,8 +192,8 @@ This is an example test want to verify the structure of a code-generated DAG aga
def test_execute_no_trigger(self):
self.ti.run(ignore_ti_state=True)
- self.assertEqual(self.ti.state, State.SUCCESS)
- #Assert something related to tasks results
+ assert self.ti.state == State.SUCCESS
+ # Assert something related to tasks results
Self-Checks
------------
@@ -247,7 +246,7 @@ For variable, use :envvar:`AIRFLOW_VAR_{KEY}`.
.. code-block:: python
with mock.patch.dict('os.environ', AIRFLOW_VAR_KEY="env-value"):
- self.assertEqual("env-value", Variable.get("key"))
+ assert "env-value" == Variable.get("key")
For connection, use :envvar:`AIRFLOW_CONN_{CONN_ID}`.
@@ -260,4 +259,4 @@ For connection, use :envvar:`AIRFLOW_CONN_{CONN_ID}`.
)
conn_uri = conn.get_uri()
with mock.patch.dict("os.environ", AIRFLOW_CONN_MY_CONN=conn_uri):
- self.assertEqual("cat", Connection.get("my_conn").login)
+ assert "cat" == Connection.get("my_conn").login
diff --git a/kubernetes_tests/test_kubernetes_executor.py b/kubernetes_tests/test_kubernetes_executor.py
index 45f77ebef7997..bd16aefaacd13 100644
--- a/kubernetes_tests/test_kubernetes_executor.py
+++ b/kubernetes_tests/test_kubernetes_executor.py
@@ -84,7 +84,7 @@ def _ensure_airflow_webserver_is_healthy(self):
timeout=1,
)
- self.assertEqual(response.status_code, 200)
+ assert response.status_code == 200
def setUp(self):
self.session = self._get_session_with_retries()
@@ -112,7 +112,7 @@ def monitor_task(self, host, execution_date, dag_id, task_id, expected_final_sta
check_call(["echo", "api returned 404."])
tries += 1
continue
- self.assertEqual(result.status_code, 200, "Could not get the status")
+ assert result.status_code == 200, "Could not get the status"
result_json = result.json()
print(f"Received [monitor_task]#2: {result_json}")
state = result_json['state']
@@ -127,7 +127,7 @@ def monitor_task(self, host, execution_date, dag_id, task_id, expected_final_sta
check_call(["echo", f"api call failed. trying again. error {e}"])
if state != expected_final_state:
print(f"The expected state is wrong {state} != {expected_final_state} (expected)!")
- self.assertEqual(state, expected_final_state)
+ assert state == expected_final_state
def ensure_dag_expected_state(self, host, execution_date, dag_id, expected_final_state, timeout):
tries = 0
@@ -140,7 +140,7 @@ def ensure_dag_expected_state(self, host, execution_date, dag_id, expected_final
print(f"Calling {get_string}")
# Trigger a new dagrun
result = self.session.get(get_string)
- self.assertEqual(result.status_code, 200, "Could not get the status")
+ assert result.status_code == 200, "Could not get the status"
result_json = result.json()
print(f"Received: {result}")
state = result_json['state']
@@ -152,7 +152,7 @@ def ensure_dag_expected_state(self, host, execution_date, dag_id, expected_final
self._describe_resources("airflow")
self._describe_resources("default")
tries += 1
- self.assertEqual(state, expected_final_state)
+ assert state == expected_final_state
# Maybe check if we can retrieve the logs, but then we need to extend the API
@@ -165,7 +165,7 @@ def start_dag(self, dag_id, host):
except ValueError:
result_json = str(result)
print(f"Received [start_dag]#1 {result_json}")
- self.assertEqual(result.status_code, 200, f"Could not enable DAG: {result_json}")
+ assert result.status_code == 200, f"Could not enable DAG: {result_json}"
post_string = f'http://{host}/api/experimental/' f'dags/{dag_id}/dag_runs'
print(f"Calling [start_dag]#2 {post_string}")
# Trigger a new dagrun
@@ -175,17 +175,15 @@ def start_dag(self, dag_id, host):
except ValueError:
result_json = str(result)
print(f"Received [start_dag]#2 {result_json}")
- self.assertEqual(result.status_code, 200, f"Could not trigger a DAG-run: {result_json}")
+ assert result.status_code == 200, f"Could not trigger a DAG-run: {result_json}"
time.sleep(1)
get_string = f'http://{host}/api/experimental/latest_runs'
print(f"Calling [start_dag]#3 {get_string}")
result = self.session.get(get_string)
- self.assertEqual(
- result.status_code,
- 200,
- "Could not get the latest DAG-run:" " {result}".format(result=result.json()),
+ assert result.status_code == 200, "Could not get the latest DAG-run:" " {result}".format(
+ result=result.json()
)
result_json = result.json()
print(f"Received: [start_dag]#3 {result_json}")
@@ -193,13 +191,13 @@ def start_dag(self, dag_id, host):
def start_job_in_kubernetes(self, dag_id, host):
result_json = self.start_dag(dag_id=dag_id, host=host)
- self.assertGreater(len(result_json['items']), 0)
+ assert len(result_json['items']) > 0
execution_date = None
for dag_run in result_json['items']:
if dag_run['dag_id'] == dag_id:
execution_date = dag_run['execution_date']
break
- self.assertIsNotNone(execution_date, f"No execution_date can be found for the dag with {dag_id}")
+ assert execution_date is not None, f"No execution_date can be found for the dag with {dag_id}"
return execution_date
def test_integration_run_dag(self):
@@ -264,6 +262,4 @@ def test_integration_run_dag_with_scheduler_failure(self):
timeout=300,
)
- self.assertEqual(
- self._num_pods_in_namespace('test-namespace'), 0, "failed to delete pods in other namespace"
- )
+ assert self._num_pods_in_namespace('test-namespace') == 0, "failed to delete pods in other namespace"
diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes_tests/test_kubernetes_pod_operator.py
index 49754a2deab3e..574f7ffa3a578 100644
--- a/kubernetes_tests/test_kubernetes_pod_operator.py
+++ b/kubernetes_tests/test_kubernetes_pod_operator.py
@@ -27,6 +27,7 @@
from unittest.mock import ANY
import pendulum
+import pytest
from kubernetes.client import models as k8s
from kubernetes.client.api_client import ApiClient
from kubernetes.client.rest import ApiException
@@ -136,7 +137,7 @@ def test_do_xcom_push_defaults_false(self):
do_xcom_push=False,
config_file=new_config_path,
)
- self.assertFalse(k.do_xcom_push)
+ assert not k.do_xcom_push
def test_config_path_move(self):
new_config_path = '/tmp/kube_config'
@@ -158,7 +159,7 @@ def test_config_path_move(self):
context = create_context(k)
k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
def test_working_pod(self):
k = KubernetesPodOperator(
@@ -175,8 +176,8 @@ def test_working_pod(self):
context = create_context(k)
k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
- self.assertEqual(self.expected_pod['spec'], actual_pod['spec'])
- self.assertEqual(self.expected_pod['metadata']['labels'], actual_pod['metadata']['labels'])
+ assert self.expected_pod['spec'] == actual_pod['spec']
+ assert self.expected_pod['metadata']['labels'] == actual_pod['metadata']['labels']
def test_delete_operator_pod(self):
k = KubernetesPodOperator(
@@ -194,8 +195,8 @@ def test_delete_operator_pod(self):
context = create_context(k)
k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
- self.assertEqual(self.expected_pod['spec'], actual_pod['spec'])
- self.assertEqual(self.expected_pod['metadata']['labels'], actual_pod['metadata']['labels'])
+ assert self.expected_pod['spec'] == actual_pod['spec']
+ assert self.expected_pod['metadata']['labels'] == actual_pod['metadata']['labels']
def test_pod_hostnetwork(self):
k = KubernetesPodOperator(
@@ -214,8 +215,8 @@ def test_pod_hostnetwork(self):
k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['hostNetwork'] = True
- self.assertEqual(self.expected_pod['spec'], actual_pod['spec'])
- self.assertEqual(self.expected_pod['metadata']['labels'], actual_pod['metadata']['labels'])
+ assert self.expected_pod['spec'] == actual_pod['spec']
+ assert self.expected_pod['metadata']['labels'] == actual_pod['metadata']['labels']
def test_pod_dnspolicy(self):
dns_policy = "ClusterFirstWithHostNet"
@@ -237,8 +238,8 @@ def test_pod_dnspolicy(self):
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['hostNetwork'] = True
self.expected_pod['spec']['dnsPolicy'] = dns_policy
- self.assertEqual(self.expected_pod['spec'], actual_pod['spec'])
- self.assertEqual(self.expected_pod['metadata']['labels'], actual_pod['metadata']['labels'])
+ assert self.expected_pod['spec'] == actual_pod['spec']
+ assert self.expected_pod['metadata']['labels'] == actual_pod['metadata']['labels']
def test_pod_schedulername(self):
scheduler_name = "default-scheduler"
@@ -258,7 +259,7 @@ def test_pod_schedulername(self):
k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['schedulerName'] = scheduler_name
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
def test_pod_node_selectors(self):
node_selectors = {'beta.kubernetes.io/os': 'linux'}
@@ -278,7 +279,7 @@ def test_pod_node_selectors(self):
k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['nodeSelector'] = node_selectors
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
def test_pod_resources(self):
resources = k8s.V1ResourceRequirements(
@@ -304,7 +305,7 @@ def test_pod_resources(self):
'requests': {'memory': '64Mi', 'cpu': '250m', 'ephemeral-storage': '1Gi'},
'limits': {'memory': '64Mi', 'cpu': 0.25, 'nvidia.com/gpu': None, 'ephemeral-storage': '2Gi'},
}
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
def test_pod_affinity(self):
affinity = {
@@ -336,7 +337,7 @@ def test_pod_affinity(self):
k.execute(context=context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['affinity'] = affinity
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
def test_port(self):
port = k8s.V1ContainerPort(
@@ -360,7 +361,7 @@ def test_port(self):
k.execute(context=context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['containers'][0]['ports'] = [{'name': 'http', 'containerPort': 80}]
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
def test_volume_mount(self):
with mock.patch.object(PodLauncher, 'log') as mock_logger:
@@ -401,7 +402,7 @@ def test_volume_mount(self):
self.expected_pod['spec']['volumes'] = [
{'name': 'test-volume', 'persistentVolumeClaim': {'claimName': 'test-volume'}}
]
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
def test_run_as_user_root(self):
security_context = {
@@ -425,7 +426,7 @@ def test_run_as_user_root(self):
k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['securityContext'] = security_context
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
def test_run_as_user_non_root(self):
security_context = {
@@ -450,7 +451,7 @@ def test_run_as_user_non_root(self):
k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['securityContext'] = security_context
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
def test_fs_group(self):
security_context = {
@@ -475,7 +476,7 @@ def test_fs_group(self):
k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['securityContext'] = security_context
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
def test_faulty_image(self):
bad_image_name = "foobar"
@@ -491,12 +492,12 @@ def test_faulty_image(self):
do_xcom_push=False,
startup_timeout_seconds=5,
)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
context = create_context(k)
k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['containers'][0]['image'] = bad_image_name
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
def test_faulty_service_account(self):
bad_service_account_name = "foobar"
@@ -513,12 +514,12 @@ def test_faulty_service_account(self):
startup_timeout_seconds=5,
service_account_name=bad_service_account_name,
)
- with self.assertRaises(ApiException):
+ with pytest.raises(ApiException):
context = create_context(k)
k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['serviceAccountName'] = bad_service_account_name
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
def test_pod_failure(self):
"""
@@ -536,12 +537,12 @@ def test_pod_failure(self):
in_cluster=False,
do_xcom_push=False,
)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
context = create_context(k)
k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['containers'][0]['args'] = bad_internal_command
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
def test_xcom_push(self):
return_value = '{"foo": "bar"\n, "buzz": 2}'
@@ -558,7 +559,7 @@ def test_xcom_push(self):
do_xcom_push=True,
)
context = create_context(k)
- self.assertEqual(k.execute(context), json.loads(return_value))
+ assert k.execute(context) == json.loads(return_value)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
volume = self.api_client.sanitize_for_serialization(PodDefaults.VOLUME)
volume_mount = self.api_client.sanitize_for_serialization(PodDefaults.VOLUME_MOUNT)
@@ -567,7 +568,7 @@ def test_xcom_push(self):
self.expected_pod['spec']['containers'][0]['volumeMounts'].insert(0, volume_mount) # noqa
self.expected_pod['spec']['volumes'].insert(0, volume)
self.expected_pod['spec']['containers'].append(container)
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
@@ -595,7 +596,7 @@ def test_envs_from_configmaps(self, mock_client, mock_monitor, mock_start):
mock_monitor.return_value = (State.SUCCESS, None)
context = create_context(k)
k.execute(context)
- self.assertEqual(mock_start.call_args[0][0].spec.containers[0].env_from, env_from)
+ assert mock_start.call_args[0][0].spec.containers[0].env_from == env_from
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
@@ -623,10 +624,9 @@ def test_envs_from_secrets(self, mock_client, monitor_mock, start_mock):
monitor_mock.return_value = (State.SUCCESS, None)
context = create_context(k)
k.execute(context)
- self.assertEqual(
- start_mock.call_args[0][0].spec.containers[0].env_from,
- [k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(name=secret_ref))],
- )
+ assert start_mock.call_args[0][0].spec.containers[0].env_from == [
+ k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(name=secret_ref))
+ ]
def test_env_vars(self):
# WHEN
@@ -662,7 +662,7 @@ def test_env_vars(self):
{'name': 'ENV2', 'value': 'val2'},
{'name': 'ENV3', 'valueFrom': {'fieldRef': {'fieldPath': 'status.podIP'}}},
]
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
def test_pod_template_file_system(self):
fixture = sys.path[0] + '/tests/kubernetes/basic_pod.yaml'
@@ -675,8 +675,8 @@ def test_pod_template_file_system(self):
context = create_context(k)
result = k.execute(context)
- self.assertIsNotNone(result)
- self.assertDictEqual(result, {"hello": "world"})
+ assert result is not None
+ assert result == {"hello": "world"}
def test_pod_template_file_with_overrides_system(self):
fixture = sys.path[0] + '/tests/kubernetes/basic_pod.yaml'
@@ -691,10 +691,10 @@ def test_pod_template_file_with_overrides_system(self):
context = create_context(k)
result = k.execute(context)
- self.assertIsNotNone(result)
- self.assertEqual(k.pod.metadata.labels, {'fizz': 'buzz', 'foo': 'bar'})
- self.assertEqual(k.pod.spec.containers[0].env, [k8s.V1EnvVar(name="env_name", value="value")])
- self.assertDictEqual(result, {"hello": "world"})
+ assert result is not None
+ assert k.pod.metadata.labels == {'fizz': 'buzz', 'foo': 'bar'}
+ assert k.pod.spec.containers[0].env == [k8s.V1EnvVar(name="env_name", value="value")]
+ assert result == {"hello": "world"}
def test_pod_template_file_with_full_pod_spec(self):
fixture = sys.path[0] + '/tests/kubernetes/basic_pod.yaml'
@@ -721,10 +721,10 @@ def test_pod_template_file_with_full_pod_spec(self):
context = create_context(k)
result = k.execute(context)
- self.assertIsNotNone(result)
- self.assertEqual(k.pod.metadata.labels, {'fizz': 'buzz', 'foo': 'bar'})
- self.assertEqual(k.pod.spec.containers[0].env, [k8s.V1EnvVar(name="env_name", value="value")])
- self.assertDictEqual(result, {"hello": "world"})
+ assert result is not None
+ assert k.pod.metadata.labels == {'fizz': 'buzz', 'foo': 'bar'}
+ assert k.pod.spec.containers[0].env == [k8s.V1EnvVar(name="env_name", value="value")]
+ assert result == {"hello": "world"}
def test_full_pod_spec(self):
pod_spec = k8s.V1Pod(
@@ -753,10 +753,10 @@ def test_full_pod_spec(self):
context = create_context(k)
result = k.execute(context)
- self.assertIsNotNone(result)
- self.assertEqual(k.pod.metadata.labels, {'fizz': 'buzz', 'foo': 'bar'})
- self.assertEqual(k.pod.spec.containers[0].env, [k8s.V1EnvVar(name="env_name", value="value")])
- self.assertDictEqual(result, {"hello": "world"})
+ assert result is not None
+ assert k.pod.metadata.labels == {'fizz': 'buzz', 'foo': 'bar'}
+ assert k.pod.spec.containers[0].env == [k8s.V1EnvVar(name="env_name", value="value")]
+ assert result == {"hello": "world"}
def test_init_container(self):
# GIVEN
@@ -811,7 +811,7 @@ def test_init_container(self):
self.expected_pod['spec']['volumes'] = [
{'name': 'test-volume', 'persistentVolumeClaim': {'claimName': 'test-volume'}}
]
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
@@ -842,7 +842,7 @@ def test_pod_template_file(
deletion_grace_period_seconds: null\
"""
).strip()
- self.assertTrue(any(line.startswith(expected_line) for line in cm.output))
+ assert any(line.startswith(expected_line) for line in cm.output)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
expected_dict = {
@@ -882,7 +882,7 @@ def test_pod_template_file(
'volumes': [{'emptyDir': {}, 'name': 'xcom'}],
},
}
- self.assertEqual(expected_dict, actual_pod)
+ assert expected_dict == actual_pod
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
@@ -912,11 +912,11 @@ def test_pod_priority_class_name(
k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['priorityClassName'] = priority_class_name
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
def test_pod_name(self):
pod_name_too_long = "a" * 221
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
KubernetesPodOperator(
namespace='default',
image="ubuntu:16.04",
@@ -953,9 +953,9 @@ def test_on_kill(self, monitor_mock): # pylint: disable=unused-argument
k.execute(context)
name = k.pod.metadata.name
pod = client.read_namespaced_pod(name=name, namespace=namespace)
- self.assertEqual(pod.status.phase, "Running")
+ assert pod.status.phase == "Running"
k.on_kill()
- with self.assertRaises(ApiException):
+ with pytest.raises(ApiException):
pod = client.read_namespaced_pod(name=name, namespace=namespace)
def test_reattach_failing_pod_once(self):
@@ -987,10 +987,10 @@ def test_reattach_failing_pod_once(self):
pod = client.read_namespaced_pod(name=name, namespace=namespace)
while pod.status.phase != "Failed":
pod = client.read_namespaced_pod(name=name, namespace=namespace)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
k.execute(context)
pod = client.read_namespaced_pod(name=name, namespace=namespace)
- self.assertEqual(pod.metadata.labels["already_checked"], "True")
+ assert pod.metadata.labels["already_checked"] == "True"
with mock.patch(
"airflow.providers.cncf.kubernetes"
".operators.kubernetes_pod.KubernetesPodOperator"
diff --git a/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py b/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py
index daefc153ada2c..88c7f3e01c04c 100644
--- a/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py
+++ b/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py
@@ -24,6 +24,7 @@
import kubernetes.client.models as k8s
import pendulum
+import pytest
from kubernetes.client.api_client import ApiClient
from kubernetes.client.rest import ApiException
@@ -147,10 +148,9 @@ def test_image_pull_secrets_correctly_set(self, mock_client, monitor_mock, start
monitor_mock.return_value = (State.SUCCESS, None)
context = self.create_context(k)
k.execute(context=context)
- self.assertEqual(
- start_mock.call_args[0][0].spec.image_pull_secrets,
- [k8s.V1LocalObjectReference(name=fake_pull_secrets)],
- )
+ assert start_mock.call_args[0][0].spec.image_pull_secrets == [
+ k8s.V1LocalObjectReference(name=fake_pull_secrets)
+ ]
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
@@ -173,7 +173,7 @@ def test_pod_delete_even_on_launcher_error(
is_delete_operator_pod=True,
)
monitor_pod_mock.side_effect = AirflowException('fake failure')
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
context = self.create_context(k)
k.execute(context=context)
assert delete_pod_mock.called
@@ -193,8 +193,8 @@ def test_working_pod(self):
context = create_context(k)
k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
- self.assertEqual(self.expected_pod['spec'], actual_pod['spec'])
- self.assertEqual(self.expected_pod['metadata']['labels'], actual_pod['metadata']['labels'])
+ assert self.expected_pod['spec'] == actual_pod['spec']
+ assert self.expected_pod['metadata']['labels'] == actual_pod['metadata']['labels']
def test_pod_node_selectors(self):
node_selectors = {'beta.kubernetes.io/os': 'linux'}
@@ -214,7 +214,7 @@ def test_pod_node_selectors(self):
k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['nodeSelector'] = node_selectors
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
def test_pod_resources(self):
resources = {
@@ -244,7 +244,7 @@ def test_pod_resources(self):
'requests': {'memory': '64Mi', 'cpu': '250m', 'ephemeral-storage': '1Gi'},
'limits': {'memory': '64Mi', 'cpu': 0.25, 'ephemeral-storage': '2Gi'},
}
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
def test_pod_affinity(self):
affinity = {
@@ -276,7 +276,7 @@ def test_pod_affinity(self):
k.execute(context=context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['affinity'] = affinity
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
def test_port(self):
port = Port('http', 80)
@@ -297,7 +297,7 @@ def test_port(self):
k.execute(context=context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['containers'][0]['ports'] = [{'name': 'http', 'containerPort': 80}]
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
def test_volume_mount(self):
with patch.object(PodLauncher, 'log') as mock_logger:
@@ -336,7 +336,7 @@ def test_volume_mount(self):
self.expected_pod['spec']['volumes'] = [
{'name': 'test-volume', 'persistentVolumeClaim': {'claimName': 'test-volume'}}
]
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
def test_run_as_user_root(self):
security_context = {
@@ -360,7 +360,7 @@ def test_run_as_user_root(self):
k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['securityContext'] = security_context
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
def test_run_as_user_non_root(self):
security_context = {
@@ -385,7 +385,7 @@ def test_run_as_user_non_root(self):
k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['securityContext'] = security_context
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
def test_fs_group(self):
security_context = {
@@ -410,7 +410,7 @@ def test_fs_group(self):
k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['securityContext'] = security_context
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
def test_faulty_service_account(self):
bad_service_account_name = "foobar"
@@ -427,12 +427,12 @@ def test_faulty_service_account(self):
startup_timeout_seconds=5,
service_account_name=bad_service_account_name,
)
- with self.assertRaises(ApiException):
+ with pytest.raises(ApiException):
context = create_context(k)
k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['serviceAccountName'] = bad_service_account_name
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
def test_pod_failure(self):
"""
@@ -450,12 +450,12 @@ def test_pod_failure(self):
in_cluster=False,
do_xcom_push=False,
)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
context = create_context(k)
k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['containers'][0]['args'] = bad_internal_command
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
def test_xcom_push(self):
return_value = '{"foo": "bar"\n, "buzz": 2}'
@@ -472,7 +472,7 @@ def test_xcom_push(self):
do_xcom_push=True,
)
context = create_context(k)
- self.assertEqual(k.execute(context), json.loads(return_value))
+ assert k.execute(context) == json.loads(return_value)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
volume = self.api_client.sanitize_for_serialization(PodDefaults.VOLUME)
volume_mount = self.api_client.sanitize_for_serialization(PodDefaults.VOLUME_MOUNT)
@@ -481,7 +481,7 @@ def test_xcom_push(self):
self.expected_pod['spec']['containers'][0]['volumeMounts'].insert(0, volume_mount) # noqa
self.expected_pod['spec']['volumes'].insert(0, volume)
self.expected_pod['spec']['containers'].append(container)
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
@@ -506,10 +506,9 @@ def test_envs_from_configmaps(self, mock_client, mock_monitor, mock_start):
mock_monitor.return_value = (State.SUCCESS, None)
context = self.create_context(k)
k.execute(context)
- self.assertEqual(
- mock_start.call_args[0][0].spec.containers[0].env_from,
- [k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(name=configmap))],
- )
+ assert mock_start.call_args[0][0].spec.containers[0].env_from == [
+ k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(name=configmap))
+ ]
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
@@ -535,10 +534,9 @@ def test_envs_from_secrets(self, mock_client, monitor_mock, start_mock):
monitor_mock.return_value = (State.SUCCESS, None)
context = self.create_context(k)
k.execute(context)
- self.assertEqual(
- start_mock.call_args[0][0].spec.containers[0].env_from,
- [k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(name=secret_ref))],
- )
+ assert start_mock.call_args[0][0].spec.containers[0].env_from == [
+ k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(name=secret_ref))
+ ]
def test_env_vars(self):
# WHEN
@@ -569,7 +567,7 @@ def test_env_vars(self):
{'name': 'ENV2', 'value': 'val2'},
{'name': 'ENV3', 'valueFrom': {'fieldRef': {'fieldPath': 'status.podIP'}}},
]
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
def test_pod_template_file_with_overrides_system(self):
fixture = sys.path[0] + '/tests/kubernetes/basic_pod.yaml'
@@ -584,10 +582,10 @@ def test_pod_template_file_with_overrides_system(self):
context = create_context(k)
result = k.execute(context)
- self.assertIsNotNone(result)
- self.assertEqual(k.pod.metadata.labels, {'fizz': 'buzz', 'foo': 'bar'})
- self.assertEqual(k.pod.spec.containers[0].env, [k8s.V1EnvVar(name="env_name", value="value")])
- self.assertDictEqual(result, {"hello": "world"})
+ assert result is not None
+ assert k.pod.metadata.labels == {'fizz': 'buzz', 'foo': 'bar'}
+ assert k.pod.spec.containers[0].env == [k8s.V1EnvVar(name="env_name", value="value")]
+ assert result == {"hello": "world"}
def test_init_container(self):
# GIVEN
@@ -641,7 +639,7 @@ def test_init_container(self):
self.expected_pod['spec']['volumes'] = [
{'name': 'test-volume', 'persistentVolumeClaim': {'claimName': 'test-volume'}}
]
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
@@ -669,11 +667,11 @@ def test_pod_priority_class_name(
k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['priorityClassName'] = priority_class_name
- self.assertEqual(self.expected_pod, actual_pod)
+ assert self.expected_pod == actual_pod
def test_pod_name(self):
pod_name_too_long = "a" * 221
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
KubernetesPodOperator(
namespace='default',
image="ubuntu:16.04",
diff --git a/pylintrc b/pylintrc
index ae2e34ad8cad2..cd9a06023bb56 100644
--- a/pylintrc
+++ b/pylintrc
@@ -154,7 +154,8 @@ disable=print-statement,
ungrouped-imports, # Disabled to avoid conflict with isort import order rules, which is enabled in the project.
missing-module-docstring,
import-outside-toplevel, # We import outside toplevel to avoid cyclic imports
- raise-missing-from # We don't use raise...from
+ raise-missing-from, # We don't use raise...from
+ misplaced-comparison-constant
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
diff --git a/tests/always/test_example_dags.py b/tests/always/test_example_dags.py
index 339a42a4c2f47..a21056dcfc087 100644
--- a/tests/always/test_example_dags.py
+++ b/tests/always/test_example_dags.py
@@ -32,7 +32,7 @@
class TestExampleDags(unittest.TestCase):
def test_should_be_importable(self):
example_dags = list(glob(f"{ROOT_FOLDER}/airflow/**/example_dags/example_*.py", recursive=True))
- self.assertNotEqual(0, len(example_dags))
+ assert 0 != len(example_dags)
for filepath in example_dags:
relative_filepath = os.path.relpath(filepath, ROOT_FOLDER)
with self.subTest(f"File {relative_filepath} should contain dags"):
@@ -40,8 +40,8 @@ def test_should_be_importable(self):
dag_folder=filepath,
include_examples=False,
)
- self.assertEqual(0, len(dagbag.import_errors), f"import_errors={str(dagbag.import_errors)}")
- self.assertGreaterEqual(len(dagbag.dag_ids), 1)
+ assert 0 == len(dagbag.import_errors), f"import_errors={str(dagbag.import_errors)}"
+ assert len(dagbag.dag_ids) >= 1
def test_should_not_do_database_queries(self):
example_dags = glob(f"{ROOT_FOLDER}/airflow/**/example_dags/example_*.py", recursive=True)
@@ -50,7 +50,7 @@ def test_should_not_do_database_queries(self):
for dag_file in example_dags
if any(not dag_file.endswith(e) for e in NO_DB_QUERY_EXCEPTION)
]
- self.assertNotEqual(0, len(example_dags))
+ assert 0 != len(example_dags)
for filepath in example_dags:
relative_filepath = os.path.relpath(filepath, ROOT_FOLDER)
with self.subTest(f"File {relative_filepath} shouldn't do database queries"):
diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py
index d111eb3e6e953..66fafc9106a5d 100644
--- a/tests/always/test_project_structure.py
+++ b/tests/always/test_project_structure.py
@@ -100,7 +100,7 @@ def test_providers_modules_should_have_tests(self):
with self.subTest("Detect missing tests in providers module"):
expected_missing_test_modules = {pair[1] for pair in expected_missing_providers_modules}
missing_tests_files = missing_tests_files - set(expected_missing_test_modules)
- self.assertEqual(set(), missing_tests_files)
+ assert set() == missing_tests_files
with self.subTest("Verify removed deprecated module also removed from deprecated list"):
expected_missing_modules = {pair[0] for pair in expected_missing_providers_modules}
@@ -247,7 +247,7 @@ def has_example_dag(operator_set):
with self.subTest("Detect missing example dags"):
missing_example = {s for s in operator_sets if not has_example_dag(s)}
missing_example -= self.MISSING_EXAMPLE_DAGS
- self.assertEqual(set(), missing_example)
+ assert set() == missing_example
with self.subTest("Keep update missing example dags list"):
new_example_dag = set(example_sets).intersection(set(self.MISSING_EXAMPLE_DAGS))
@@ -299,7 +299,7 @@ def test_missing_example_for_operator(self):
print("example_paths=", example_paths)
operators_paths = set(get_classes_from_file(f"{ROOT_FOLDER}/{filepath}"))
missing_operators.extend(operators_paths - example_paths)
- self.assertEqual(set(missing_operators), self.MISSING_EXAMPLES_FOR_OPERATORS)
+ assert set(missing_operators) == self.MISSING_EXAMPLES_FOR_OPERATORS
@parameterized.expand(
itertools.product(["_system.py", "_system_helper.py"], ["operators", "sensors", "transfers"])
@@ -314,7 +314,7 @@ def test_detect_invalid_system_tests(self, resource_type, filename_suffix):
expected_files = (f.replace(".py", filename_suffix).replace("/test_", "/") for f in expected_files)
expected_files = {f'{f.rpartition("/")[0]}/test_{f.rpartition("/")[2]}' for f in expected_files}
- self.assertEqual(set(), files - expected_files)
+ assert set() == files - expected_files
@staticmethod
def find_resource_files(
@@ -346,4 +346,4 @@ def test_no_illegal_suffixes(self):
invalid_files = [f for f in files if any(f.endswith(suffix) for suffix in illegal_suffixes)]
- self.assertEqual([], invalid_files)
+ assert [] == invalid_files
diff --git a/tests/api/auth/backend/test_kerberos_auth.py b/tests/api/auth/backend/test_kerberos_auth.py
index ef31679fd28b6..90113adbf1aba 100644
--- a/tests/api/auth/backend/test_kerberos_auth.py
+++ b/tests/api/auth/backend/test_kerberos_auth.py
@@ -64,7 +64,7 @@ def test_trigger_dag(self):
data=json.dumps(dict(run_id='my_run' + datetime.now().isoformat())),
content_type="application/json",
)
- self.assertEqual(401, response.status_code)
+ assert 401 == response.status_code
response.url = f'http://{socket.getfqdn()}'
@@ -81,7 +81,7 @@ class Request:
CLIENT_AUTH.mutual_authentication = 3
CLIENT_AUTH.handle_response(response)
- self.assertIn('Authorization', response.request.headers)
+ assert 'Authorization' in response.request.headers
response2 = client.post(
url_template.format('example_bash_operator'),
@@ -89,7 +89,7 @@ class Request:
content_type="application/json",
headers=response.request.headers,
)
- self.assertEqual(200, response2.status_code)
+ assert 200 == response2.status_code
def test_unauthorized(self):
with self.app.test_client() as client:
@@ -100,4 +100,4 @@ def test_unauthorized(self):
content_type="application/json",
)
- self.assertEqual(401, response.status_code)
+ assert 401 == response.status_code
diff --git a/tests/api/auth/test_client.py b/tests/api/auth/test_client.py
index 8652b12772b5f..cf3c7b5fdba18 100644
--- a/tests/api/auth/test_client.py
+++ b/tests/api/auth/test_client.py
@@ -38,7 +38,7 @@ def test_should_create_client(self, mock_client):
mock_client.assert_called_once_with(
api_base_url='http://localhost:1234', auth='CLIENT_AUTH', session=None
)
- self.assertEqual(mock_client.return_value, result)
+ assert mock_client.return_value == result
@mock.patch("airflow.api.client.json_client.Client")
@mock.patch("airflow.providers.google.common.auth_backend.google_openid.create_client_session")
@@ -55,4 +55,4 @@ def test_should_create_google_open_id_client(self, mock_create_client_session, m
mock_client.assert_called_once_with(
api_base_url='http://localhost:1234', auth=None, session=mock_create_client_session.return_value
)
- self.assertEqual(mock_client.return_value, result)
+ assert mock_client.return_value == result
diff --git a/tests/api/client/test_local_client.py b/tests/api/client/test_local_client.py
index d574615a243cf..c20a1b98112ed 100644
--- a/tests/api/client/test_local_client.py
+++ b/tests/api/client/test_local_client.py
@@ -20,6 +20,7 @@
import unittest
from unittest.mock import ANY, patch
+import pytest
from freezegun import freeze_time
from airflow.api.client.local_client import Client
@@ -60,7 +61,7 @@ def test_trigger_dag(self, mock):
DagBag(include_examples=True)
# non existent
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.client.trigger_dag(dag_id="blablabla")
with freeze_time(EXECDATE):
@@ -118,36 +119,36 @@ def test_delete_dag(self):
key = "my_dag_id"
with create_session() as session:
- self.assertEqual(session.query(DagModel).filter(DagModel.dag_id == key).count(), 0)
+ assert session.query(DagModel).filter(DagModel.dag_id == key).count() == 0
session.add(DagModel(dag_id=key))
with create_session() as session:
- self.assertEqual(session.query(DagModel).filter(DagModel.dag_id == key).count(), 1)
+ assert session.query(DagModel).filter(DagModel.dag_id == key).count() == 1
self.client.delete_dag(dag_id=key)
- self.assertEqual(session.query(DagModel).filter(DagModel.dag_id == key).count(), 0)
+ assert session.query(DagModel).filter(DagModel.dag_id == key).count() == 0
def test_get_pool(self):
self.client.create_pool(name='foo', slots=1, description='')
pool = self.client.get_pool(name='foo')
- self.assertEqual(pool, ('foo', 1, ''))
+ assert pool == ('foo', 1, '')
def test_get_pools(self):
self.client.create_pool(name='foo1', slots=1, description='')
self.client.create_pool(name='foo2', slots=2, description='')
pools = sorted(self.client.get_pools(), key=lambda p: p[0])
- self.assertEqual(pools, [('default_pool', 128, 'Default pool'), ('foo1', 1, ''), ('foo2', 2, '')])
+ assert pools == [('default_pool', 128, 'Default pool'), ('foo1', 1, ''), ('foo2', 2, '')]
def test_create_pool(self):
pool = self.client.create_pool(name='foo', slots=1, description='')
- self.assertEqual(pool, ('foo', 1, ''))
+ assert pool == ('foo', 1, '')
with create_session() as session:
- self.assertEqual(session.query(Pool).count(), 2)
+ assert session.query(Pool).count() == 2
def test_delete_pool(self):
self.client.create_pool(name='foo', slots=1, description='')
with create_session() as session:
- self.assertEqual(session.query(Pool).count(), 2)
+ assert session.query(Pool).count() == 2
self.client.delete_pool(name='foo')
with create_session() as session:
- self.assertEqual(session.query(Pool).count(), 1)
+ assert session.query(Pool).count() == 1
diff --git a/tests/api/common/experimental/test_delete_dag.py b/tests/api/common/experimental/test_delete_dag.py
index 4d5b47fabaa99..7570cb8a7a2a6 100644
--- a/tests/api/common/experimental/test_delete_dag.py
+++ b/tests/api/common/experimental/test_delete_dag.py
@@ -18,6 +18,8 @@
import unittest
+import pytest
+
from airflow import models
from airflow.api.common.experimental.delete_dag import delete_dag
from airflow.exceptions import DagNotFound
@@ -46,7 +48,7 @@ def tearDown(self):
self.dag.clear()
def test_delete_dag_non_existent_dag(self):
- with self.assertRaises(DagNotFound):
+ with pytest.raises(DagNotFound):
delete_dag("non-existent DAG")
@@ -112,23 +114,23 @@ def tearDown(self):
def check_dag_models_exists(self):
with create_session() as session:
- self.assertEqual(session.query(DM).filter(DM.dag_id == self.key).count(), 1)
- self.assertEqual(session.query(DR).filter(DR.dag_id == self.key).count(), 1)
- self.assertEqual(session.query(TI).filter(TI.dag_id == self.key).count(), 1)
- self.assertEqual(session.query(TF).filter(TF.dag_id == self.key).count(), 1)
- self.assertEqual(session.query(TR).filter(TR.dag_id == self.key).count(), 1)
- self.assertEqual(session.query(LOG).filter(LOG.dag_id == self.key).count(), 1)
- self.assertEqual(session.query(IE).filter(IE.filename == self.dag_file_path).count(), 1)
+ assert session.query(DM).filter(DM.dag_id == self.key).count() == 1
+ assert session.query(DR).filter(DR.dag_id == self.key).count() == 1
+ assert session.query(TI).filter(TI.dag_id == self.key).count() == 1
+ assert session.query(TF).filter(TF.dag_id == self.key).count() == 1
+ assert session.query(TR).filter(TR.dag_id == self.key).count() == 1
+ assert session.query(LOG).filter(LOG.dag_id == self.key).count() == 1
+ assert session.query(IE).filter(IE.filename == self.dag_file_path).count() == 1
def check_dag_models_removed(self, expect_logs=1):
with create_session() as session:
- self.assertEqual(session.query(DM).filter(DM.dag_id == self.key).count(), 0)
- self.assertEqual(session.query(DR).filter(DR.dag_id == self.key).count(), 0)
- self.assertEqual(session.query(TI).filter(TI.dag_id == self.key).count(), 0)
- self.assertEqual(session.query(TF).filter(TF.dag_id == self.key).count(), 0)
- self.assertEqual(session.query(TR).filter(TR.dag_id == self.key).count(), 0)
- self.assertEqual(session.query(LOG).filter(LOG.dag_id == self.key).count(), expect_logs)
- self.assertEqual(session.query(IE).filter(IE.filename == self.dag_file_path).count(), 0)
+ assert session.query(DM).filter(DM.dag_id == self.key).count() == 0
+ assert session.query(DR).filter(DR.dag_id == self.key).count() == 0
+ assert session.query(TI).filter(TI.dag_id == self.key).count() == 0
+ assert session.query(TF).filter(TF.dag_id == self.key).count() == 0
+ assert session.query(TR).filter(TR.dag_id == self.key).count() == 0
+ assert session.query(LOG).filter(LOG.dag_id == self.key).count() == expect_logs
+ assert session.query(IE).filter(IE.filename == self.dag_file_path).count() == 0
def test_delete_dag_successful_delete(self):
self.setup_dag_models()
diff --git a/tests/api/common/experimental/test_mark_tasks.py b/tests/api/common/experimental/test_mark_tasks.py
index 4ff093100c93f..a9f1c8641fb90 100644
--- a/tests/api/common/experimental/test_mark_tasks.py
+++ b/tests/api/common/experimental/test_mark_tasks.py
@@ -104,18 +104,18 @@ def verify_state(self, dag, task_ids, execution_dates, state, old_tis, session=N
tis = session.query(TI).filter(TI.dag_id == dag.dag_id, TI.execution_date.in_(execution_dates)).all()
- self.assertTrue(len(tis) > 0)
+ assert len(tis) > 0
for ti in tis: # pylint: disable=too-many-nested-blocks
- self.assertEqual(ti.operator, dag.get_task(ti.task_id).task_type)
+ assert ti.operator == dag.get_task(ti.task_id).task_type
if ti.task_id in task_ids and ti.execution_date in execution_dates:
- self.assertEqual(ti.state, state)
+ assert ti.state == state
if state in State.finished:
- self.assertIsNotNone(ti.end_date)
+ assert ti.end_date is not None
else:
for old_ti in old_tis:
if old_ti.task_id == ti.task_id and old_ti.execution_date == ti.execution_date:
- self.assertEqual(ti.state, old_ti.state)
+ assert ti.state == old_ti.state
def test_mark_tasks_now(self):
# set one task to success but do not commit
@@ -131,7 +131,7 @@ def test_mark_tasks_now(self):
state=State.SUCCESS,
commit=False,
)
- self.assertEqual(len(altered), 1)
+ assert len(altered) == 1
self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], None, snapshot)
# set one and only one task to success
@@ -145,7 +145,7 @@ def test_mark_tasks_now(self):
state=State.SUCCESS,
commit=True,
)
- self.assertEqual(len(altered), 1)
+ assert len(altered) == 1
self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], State.SUCCESS, snapshot)
# set no tasks
@@ -159,7 +159,7 @@ def test_mark_tasks_now(self):
state=State.SUCCESS,
commit=True,
)
- self.assertEqual(len(altered), 0)
+ assert len(altered) == 0
self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], State.SUCCESS, snapshot)
# set task to other than success
@@ -173,7 +173,7 @@ def test_mark_tasks_now(self):
state=State.FAILED,
commit=True,
)
- self.assertEqual(len(altered), 1)
+ assert len(altered) == 1
self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], State.FAILED, snapshot)
# don't alter other tasks
@@ -189,7 +189,7 @@ def test_mark_tasks_now(self):
state=State.SUCCESS,
commit=True,
)
- self.assertEqual(len(altered), 1)
+ assert len(altered) == 1
self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], State.SUCCESS, snapshot)
# set one task as FAILED. dag3 has schedule_interval None
@@ -206,7 +206,7 @@ def test_mark_tasks_now(self):
commit=True,
)
# exactly one TaskInstance should have been altered
- self.assertEqual(len(altered), 1)
+ assert len(altered) == 1
# task should have been marked as failed
self.verify_state(self.dag3, [task.task_id], [self.dag3_execution_dates[1]], State.FAILED, snapshot)
# tasks on other days should be unchanged
@@ -231,7 +231,7 @@ def test_mark_downstream(self):
state=State.SUCCESS,
commit=True,
)
- self.assertEqual(len(altered), 3)
+ assert len(altered) == 3
self.verify_state(self.dag1, task_ids, [self.execution_dates[0]], State.SUCCESS, snapshot)
def test_mark_upstream(self):
@@ -252,7 +252,7 @@ def test_mark_upstream(self):
state=State.SUCCESS,
commit=True,
)
- self.assertEqual(len(altered), 4)
+ assert len(altered) == 4
self.verify_state(self.dag1, task_ids, [self.execution_dates[0]], State.SUCCESS, snapshot)
def test_mark_tasks_future(self):
@@ -269,7 +269,7 @@ def test_mark_tasks_future(self):
state=State.SUCCESS,
commit=True,
)
- self.assertEqual(len(altered), 2)
+ assert len(altered) == 2
self.verify_state(self.dag1, [task.task_id], self.execution_dates, State.SUCCESS, snapshot)
snapshot = TestMarkTasks.snapshot_state(self.dag3, self.dag3_execution_dates)
@@ -284,7 +284,7 @@ def test_mark_tasks_future(self):
state=State.FAILED,
commit=True,
)
- self.assertEqual(len(altered), 2)
+ assert len(altered) == 2
self.verify_state(self.dag3, [task.task_id], [self.dag3_execution_dates[0]], None, snapshot)
self.verify_state(self.dag3, [task.task_id], self.dag3_execution_dates[1:], State.FAILED, snapshot)
@@ -302,7 +302,7 @@ def test_mark_tasks_past(self):
state=State.SUCCESS,
commit=True,
)
- self.assertEqual(len(altered), 2)
+ assert len(altered) == 2
self.verify_state(self.dag1, [task.task_id], self.execution_dates, State.SUCCESS, snapshot)
snapshot = TestMarkTasks.snapshot_state(self.dag3, self.dag3_execution_dates)
@@ -317,7 +317,7 @@ def test_mark_tasks_past(self):
state=State.FAILED,
commit=True,
)
- self.assertEqual(len(altered), 2)
+ assert len(altered) == 2
self.verify_state(self.dag3, [task.task_id], self.dag3_execution_dates[:2], State.FAILED, snapshot)
self.verify_state(self.dag3, [task.task_id], [self.dag3_execution_dates[2]], None, snapshot)
@@ -335,7 +335,7 @@ def test_mark_tasks_multiple(self):
state=State.SUCCESS,
commit=True,
)
- self.assertEqual(len(altered), 2)
+ assert len(altered) == 2
self.verify_state(
self.dag1, [task.task_id for task in tasks], [self.execution_dates[0]], State.SUCCESS, snapshot
)
@@ -361,7 +361,7 @@ def test_mark_tasks_subdag(self):
state=State.SUCCESS,
commit=True,
)
- self.assertEqual(len(altered), 14)
+ assert len(altered) == 14
# cannot use snapshot here as that will require drilling down the
# sub dag tree essentially recreating the same code as in the
@@ -397,19 +397,19 @@ def _set_default_task_instance_states(self, dr):
dr.get_task_instance('run_this_last').set_state(State.FAILED)
def _verify_task_instance_states_remain_default(self, dr):
- self.assertEqual(dr.get_task_instance('runme_0').state, State.SUCCESS)
- self.assertEqual(dr.get_task_instance('runme_1').state, State.SKIPPED)
- self.assertEqual(dr.get_task_instance('runme_2').state, State.UP_FOR_RETRY)
- self.assertEqual(dr.get_task_instance('also_run_this').state, State.QUEUED)
- self.assertEqual(dr.get_task_instance('run_after_loop').state, State.RUNNING)
- self.assertEqual(dr.get_task_instance('run_this_last').state, State.FAILED)
+ assert dr.get_task_instance('runme_0').state == State.SUCCESS
+ assert dr.get_task_instance('runme_1').state == State.SKIPPED
+ assert dr.get_task_instance('runme_2').state == State.UP_FOR_RETRY
+ assert dr.get_task_instance('also_run_this').state == State.QUEUED
+ assert dr.get_task_instance('run_after_loop').state == State.RUNNING
+ assert dr.get_task_instance('run_this_last').state == State.FAILED
@provide_session
def _verify_task_instance_states(self, dag, date, state, session=None):
TI = models.TaskInstance
tis = session.query(TI).filter(TI.dag_id == dag.dag_id, TI.execution_date == date)
for ti in tis:
- self.assertEqual(ti.state, state)
+ assert ti.state == state
def _create_test_dag_run(self, state, date):
return self.dag1.create_dagrun(run_type=DagRunType.MANUAL, state=state, execution_date=date)
@@ -418,7 +418,7 @@ def _verify_dag_run_state(self, dag, date, state):
drs = models.DagRun.find(dag_id=dag.dag_id, execution_date=date)
dr = drs[0]
- self.assertEqual(dr.get_state(), state)
+ assert dr.get_state() == state
@provide_session
def _verify_dag_run_dates(self, dag, date, state, middle_time, session=None):
@@ -428,13 +428,13 @@ def _verify_dag_run_dates(self, dag, date, state, middle_time, session=None):
dr = session.query(DR).filter(DR.dag_id == dag.dag_id, DR.execution_date == date).one()
if state == State.RUNNING:
# Since the DAG is running, the start_date must be updated after creation
- self.assertGreater(dr.start_date, middle_time)
+ assert dr.start_date > middle_time
# If the dag is still running, we don't have an end date
- self.assertIsNone(dr.end_date)
+ assert dr.end_date is None
else:
# If the dag is not running, there must be an end time
- self.assertLess(dr.start_date, middle_time)
- self.assertGreater(dr.end_date, middle_time)
+ assert dr.start_date < middle_time
+ assert dr.end_date > middle_time
def test_set_running_dag_run_to_success(self):
date = self.execution_dates[0]
@@ -445,7 +445,7 @@ def test_set_running_dag_run_to_success(self):
altered = set_dag_run_state_to_success(self.dag1, date, commit=True)
# All except the SUCCESS task should be altered.
- self.assertEqual(len(altered), 5)
+ assert len(altered) == 5
self._verify_dag_run_state(self.dag1, date, State.SUCCESS)
self._verify_task_instance_states(self.dag1, date, State.SUCCESS)
self._verify_dag_run_dates(self.dag1, date, State.SUCCESS, middle_time)
@@ -459,9 +459,9 @@ def test_set_running_dag_run_to_failed(self):
altered = set_dag_run_state_to_failed(self.dag1, date, commit=True)
# Only running task should be altered.
- self.assertEqual(len(altered), 1)
+ assert len(altered) == 1
self._verify_dag_run_state(self.dag1, date, State.FAILED)
- self.assertEqual(dr.get_task_instance('run_after_loop').state, State.FAILED)
+ assert dr.get_task_instance('run_after_loop').state == State.FAILED
self._verify_dag_run_dates(self.dag1, date, State.FAILED, middle_time)
def test_set_running_dag_run_to_running(self):
@@ -473,7 +473,7 @@ def test_set_running_dag_run_to_running(self):
altered = set_dag_run_state_to_running(self.dag1, date, commit=True)
# None of the tasks should be altered, only the dag itself
- self.assertEqual(len(altered), 0)
+ assert len(altered) == 0
self._verify_dag_run_state(self.dag1, date, State.RUNNING)
self._verify_task_instance_states_remain_default(dr)
self._verify_dag_run_dates(self.dag1, date, State.RUNNING, middle_time)
@@ -487,7 +487,7 @@ def test_set_success_dag_run_to_success(self):
altered = set_dag_run_state_to_success(self.dag1, date, commit=True)
# All except the SUCCESS task should be altered.
- self.assertEqual(len(altered), 5)
+ assert len(altered) == 5
self._verify_dag_run_state(self.dag1, date, State.SUCCESS)
self._verify_task_instance_states(self.dag1, date, State.SUCCESS)
self._verify_dag_run_dates(self.dag1, date, State.SUCCESS, middle_time)
@@ -501,9 +501,9 @@ def test_set_success_dag_run_to_failed(self):
altered = set_dag_run_state_to_failed(self.dag1, date, commit=True)
# Only running task should be altered.
- self.assertEqual(len(altered), 1)
+ assert len(altered) == 1
self._verify_dag_run_state(self.dag1, date, State.FAILED)
- self.assertEqual(dr.get_task_instance('run_after_loop').state, State.FAILED)
+ assert dr.get_task_instance('run_after_loop').state == State.FAILED
self._verify_dag_run_dates(self.dag1, date, State.FAILED, middle_time)
def test_set_success_dag_run_to_running(self):
@@ -515,7 +515,7 @@ def test_set_success_dag_run_to_running(self):
altered = set_dag_run_state_to_running(self.dag1, date, commit=True)
# None of the tasks should be altered, but only the dag object should be changed
- self.assertEqual(len(altered), 0)
+ assert len(altered) == 0
self._verify_dag_run_state(self.dag1, date, State.RUNNING)
self._verify_task_instance_states_remain_default(dr)
self._verify_dag_run_dates(self.dag1, date, State.RUNNING, middle_time)
@@ -529,7 +529,7 @@ def test_set_failed_dag_run_to_success(self):
altered = set_dag_run_state_to_success(self.dag1, date, commit=True)
# All except the SUCCESS task should be altered.
- self.assertEqual(len(altered), 5)
+ assert len(altered) == 5
self._verify_dag_run_state(self.dag1, date, State.SUCCESS)
self._verify_task_instance_states(self.dag1, date, State.SUCCESS)
self._verify_dag_run_dates(self.dag1, date, State.SUCCESS, middle_time)
@@ -543,9 +543,9 @@ def test_set_failed_dag_run_to_failed(self):
altered = set_dag_run_state_to_failed(self.dag1, date, commit=True)
# Only running task should be altered.
- self.assertEqual(len(altered), 1)
+ assert len(altered) == 1
self._verify_dag_run_state(self.dag1, date, State.FAILED)
- self.assertEqual(dr.get_task_instance('run_after_loop').state, State.FAILED)
+ assert dr.get_task_instance('run_after_loop').state == State.FAILED
self._verify_dag_run_dates(self.dag1, date, State.FAILED, middle_time)
def test_set_failed_dag_run_to_running(self):
@@ -559,7 +559,7 @@ def test_set_failed_dag_run_to_running(self):
altered = set_dag_run_state_to_running(self.dag1, date, commit=True)
# None of the tasks should be altered, since we've only altered the DAG itself
- self.assertEqual(len(altered), 0)
+ assert len(altered) == 0
self._verify_dag_run_state(self.dag1, date, State.RUNNING)
self._verify_task_instance_states_remain_default(dr)
self._verify_dag_run_dates(self.dag1, date, State.RUNNING, middle_time)
@@ -572,21 +572,21 @@ def test_set_state_without_commit(self):
will_be_altered = set_dag_run_state_to_running(self.dag1, date, commit=False)
# None of the tasks will be altered.
- self.assertEqual(len(will_be_altered), 0)
+ assert len(will_be_altered) == 0
self._verify_dag_run_state(self.dag1, date, State.RUNNING)
self._verify_task_instance_states_remain_default(dr)
will_be_altered = set_dag_run_state_to_failed(self.dag1, date, commit=False)
# Only the running task will be altered.
- self.assertEqual(len(will_be_altered), 1)
+ assert len(will_be_altered) == 1
self._verify_dag_run_state(self.dag1, date, State.RUNNING)
self._verify_task_instance_states_remain_default(dr)
will_be_altered = set_dag_run_state_to_success(self.dag1, date, commit=False)
# All except the SUCCESS task should be altered.
- self.assertEqual(len(will_be_altered), 5)
+ assert len(will_be_altered) == 5
self._verify_dag_run_state(self.dag1, date, State.RUNNING)
self._verify_task_instance_states_remain_default(dr)
@@ -620,7 +620,7 @@ def count_dag_tasks(dag):
count += sum(subdag_counts)
return count
- self.assertEqual(len(altered), count_dag_tasks(self.dag2))
+ assert len(altered) == count_dag_tasks(self.dag2)
self._verify_dag_run_state(self.dag2, self.execution_dates[1], State.SUCCESS)
# Make sure other dag status are not changed
@@ -632,29 +632,29 @@ def count_dag_tasks(dag):
def test_set_dag_run_state_edge_cases(self):
# Dag does not exist
altered = set_dag_run_state_to_success(None, self.execution_dates[0])
- self.assertEqual(len(altered), 0)
+ assert len(altered) == 0
altered = set_dag_run_state_to_failed(None, self.execution_dates[0])
- self.assertEqual(len(altered), 0)
+ assert len(altered) == 0
altered = set_dag_run_state_to_running(None, self.execution_dates[0])
- self.assertEqual(len(altered), 0)
+ assert len(altered) == 0
# Invalid execution date
altered = set_dag_run_state_to_success(self.dag1, None)
- self.assertEqual(len(altered), 0)
+ assert len(altered) == 0
altered = set_dag_run_state_to_failed(self.dag1, None)
- self.assertEqual(len(altered), 0)
+ assert len(altered) == 0
altered = set_dag_run_state_to_running(self.dag1, None)
- self.assertEqual(len(altered), 0)
+ assert len(altered) == 0
# This will throw ValueError since dag.latest_execution_date
# need to be 0 does not exist.
- self.assertRaises(
- ValueError, set_dag_run_state_to_success, self.dag2, timezone.make_naive(self.execution_dates[0])
- )
+ with pytest.raises(ValueError):
+ set_dag_run_state_to_success(self.dag2, timezone.make_naive(self.execution_dates[0]))
# altered = set_dag_run_state_to_success(self.dag1, self.execution_dates[0])
# DagRun does not exist
# This will throw ValueError since dag.latest_execution_date does not exist
- self.assertRaises(ValueError, set_dag_run_state_to_success, self.dag2, self.execution_dates[0])
+ with pytest.raises(ValueError):
+ set_dag_run_state_to_success(self.dag2, self.execution_dates[0])
def test_set_dag_run_state_to_failed_no_running_tasks(self):
"""
diff --git a/tests/api/common/experimental/test_pool.py b/tests/api/common/experimental/test_pool.py
index 433823978a39b..ae0022663e20f 100644
--- a/tests/api/common/experimental/test_pool.py
+++ b/tests/api/common/experimental/test_pool.py
@@ -20,6 +20,8 @@
import string
import unittest
+import pytest
+
from airflow import models
from airflow.api.common.experimental import pool as pool_api
from airflow.exceptions import AirflowBadRequest, PoolNotFound
@@ -49,88 +51,82 @@ def setUp(self):
def test_get_pool(self):
pool = pool_api.get_pool(name=self.pools[0].pool)
- self.assertEqual(pool.pool, self.pools[0].pool)
+ assert pool.pool == self.pools[0].pool
def test_get_pool_non_existing(self):
- self.assertRaisesRegex(PoolNotFound, "^Pool 'test' doesn't exist$", pool_api.get_pool, name='test')
+ with pytest.raises(PoolNotFound, match="^Pool 'test' doesn't exist$"):
+ pool_api.get_pool(name='test')
def test_get_pool_bad_name(self):
for name in ('', ' '):
- self.assertRaisesRegex(
- AirflowBadRequest, "^Pool name shouldn't be empty$", pool_api.get_pool, name=name
- )
+ with pytest.raises(AirflowBadRequest, match="^Pool name shouldn't be empty$"):
+ pool_api.get_pool(name=name)
def test_get_pools(self):
pools = sorted(pool_api.get_pools(), key=lambda p: p.pool)
- self.assertEqual(pools[0].pool, self.pools[0].pool)
- self.assertEqual(pools[1].pool, self.pools[1].pool)
+ assert pools[0].pool == self.pools[0].pool
+ assert pools[1].pool == self.pools[1].pool
def test_create_pool(self):
pool = pool_api.create_pool(name='foo', slots=5, description='')
- self.assertEqual(pool.pool, 'foo')
- self.assertEqual(pool.slots, 5)
- self.assertEqual(pool.description, '')
+ assert pool.pool == 'foo'
+ assert pool.slots == 5
+ assert pool.description == ''
with create_session() as session:
- self.assertEqual(session.query(models.Pool).count(), self.TOTAL_POOL_COUNT + 1)
+ assert session.query(models.Pool).count() == self.TOTAL_POOL_COUNT + 1
def test_create_pool_existing(self):
pool = pool_api.create_pool(name=self.pools[0].pool, slots=5, description='')
- self.assertEqual(pool.pool, self.pools[0].pool)
- self.assertEqual(pool.slots, 5)
- self.assertEqual(pool.description, '')
+ assert pool.pool == self.pools[0].pool
+ assert pool.slots == 5
+ assert pool.description == ''
with create_session() as session:
- self.assertEqual(session.query(models.Pool).count(), self.TOTAL_POOL_COUNT)
+ assert session.query(models.Pool).count() == self.TOTAL_POOL_COUNT
def test_create_pool_bad_name(self):
for name in ('', ' '):
- self.assertRaisesRegex(
- AirflowBadRequest,
- "^Pool name shouldn't be empty$",
- pool_api.create_pool,
- name=name,
- slots=5,
- description='',
- )
+ with pytest.raises(AirflowBadRequest, match="^Pool name shouldn't be empty$"):
+ pool_api.create_pool(
+ name=name,
+ slots=5,
+ description='',
+ )
def test_create_pool_name_too_long(self):
long_name = ''.join(random.choices(string.ascii_lowercase, k=300))
column_length = models.Pool.pool.property.columns[0].type.length
- self.assertRaisesRegex(
- AirflowBadRequest,
- "^Pool name can't be more than %d characters$" % column_length,
- pool_api.create_pool,
- name=long_name,
- slots=5,
- description='',
- )
+ with pytest.raises(
+ AirflowBadRequest, match="^Pool name can't be more than %d characters$" % column_length
+ ):
+ pool_api.create_pool(
+ name=long_name,
+ slots=5,
+ description='',
+ )
def test_create_pool_bad_slots(self):
- self.assertRaisesRegex(
- AirflowBadRequest,
- "^Bad value for `slots`: foo$",
- pool_api.create_pool,
- name='foo',
- slots='foo',
- description='',
- )
+ with pytest.raises(AirflowBadRequest, match="^Bad value for `slots`: foo$"):
+ pool_api.create_pool(
+ name='foo',
+ slots='foo',
+ description='',
+ )
def test_delete_pool(self):
pool = pool_api.delete_pool(name=self.pools[-1].pool)
- self.assertEqual(pool.pool, self.pools[-1].pool)
+ assert pool.pool == self.pools[-1].pool
with create_session() as session:
- self.assertEqual(session.query(models.Pool).count(), self.TOTAL_POOL_COUNT - 1)
+ assert session.query(models.Pool).count() == self.TOTAL_POOL_COUNT - 1
def test_delete_pool_non_existing(self):
- self.assertRaisesRegex(
- pool_api.PoolNotFound, "^Pool 'test' doesn't exist$", pool_api.delete_pool, name='test'
- )
+ with pytest.raises(pool_api.PoolNotFound, match="^Pool 'test' doesn't exist$"):
+ pool_api.delete_pool(name='test')
def test_delete_pool_bad_name(self):
for name in ('', ' '):
- self.assertRaisesRegex(
- AirflowBadRequest, "^Pool name shouldn't be empty$", pool_api.delete_pool, name=name
- )
+ with pytest.raises(AirflowBadRequest, match="^Pool name shouldn't be empty$"):
+ pool_api.delete_pool(name=name)
def test_delete_default_pool_not_allowed(self):
- with self.assertRaisesRegex(AirflowBadRequest, "^default_pool cannot be deleted$"):
+ with pytest.raises(AirflowBadRequest, match="^default_pool cannot be deleted$"):
pool_api.delete_pool(Pool.DEFAULT_POOL_NAME)
diff --git a/tests/api/common/experimental/test_trigger_dag.py b/tests/api/common/experimental/test_trigger_dag.py
index 9fb772d3576f9..cbca935438438 100644
--- a/tests/api/common/experimental/test_trigger_dag.py
+++ b/tests/api/common/experimental/test_trigger_dag.py
@@ -19,6 +19,7 @@
import unittest
from unittest import mock
+import pytest
from parameterized import parameterized
from airflow.api.common.experimental.trigger_dag import _trigger_dag
@@ -38,7 +39,7 @@ def tearDown(self) -> None:
@mock.patch('airflow.models.DagBag')
def test_trigger_dag_dag_not_found(self, dag_bag_mock):
dag_bag_mock.dags = {}
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
_trigger_dag('dag_not_found', dag_bag_mock)
@mock.patch('airflow.api.common.experimental.trigger_dag.DagRun', spec=DagRun)
@@ -49,7 +50,7 @@ def test_trigger_dag_dag_run_exist(self, dag_bag_mock, dag_run_mock):
dag_bag_mock.dags = [dag_id]
dag_bag_mock.get_dag.return_value = dag
dag_run_mock.find.return_value = DagRun()
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
_trigger_dag(dag_id, dag_bag_mock)
@mock.patch('airflow.models.DAG')
@@ -66,7 +67,7 @@ def test_trigger_dag_include_subdags(self, dag_bag_mock, dag_run_mock, dag_mock)
triggers = _trigger_dag(dag_id, dag_bag_mock)
- self.assertEqual(3, len(triggers))
+ assert 3 == len(triggers)
@mock.patch('airflow.models.DAG')
@mock.patch('airflow.api.common.experimental.trigger_dag.DagRun', spec=DagRun)
@@ -82,7 +83,7 @@ def test_trigger_dag_include_nested_subdags(self, dag_bag_mock, dag_run_mock, da
triggers = _trigger_dag(dag_id, dag_bag_mock)
- self.assertEqual(3, len(triggers))
+ assert 3 == len(triggers)
@mock.patch('airflow.models.DagBag')
def test_trigger_dag_with_too_early_start_date(self, dag_bag_mock):
@@ -91,7 +92,7 @@ def test_trigger_dag_with_too_early_start_date(self, dag_bag_mock):
dag_bag_mock.dags = [dag_id]
dag_bag_mock.get_dag.return_value = dag
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
_trigger_dag(dag_id, dag_bag_mock, execution_date=timezone.datetime(2015, 7, 5, 10, 10, 0))
@mock.patch('airflow.models.DagBag')
@@ -124,4 +125,4 @@ def test_trigger_dag_with_conf(self, conf, expected_conf, dag_bag_mock):
triggers = _trigger_dag(dag_id, dag_bag_mock, conf=conf)
- self.assertEqual(triggers[0].conf, expected_conf)
+ assert triggers[0].conf == expected_conf
diff --git a/tests/api_connexion/endpoints/test_connection_endpoint.py b/tests/api_connexion/endpoints/test_connection_endpoint.py
index 4ca560edf28fd..d092e42dc6678 100644
--- a/tests/api_connexion/endpoints/test_connection_endpoint.py
+++ b/tests/api_connexion/endpoints/test_connection_endpoint.py
@@ -87,15 +87,12 @@ def test_delete_should_respond_404(self):
"/api/v1/connections/test-connection", environ_overrides={'REMOTE_USER': "test"}
)
assert response.status_code == 404
- self.assertEqual(
- response.json,
- {
- 'detail': "The Connection with connection_id: `test-connection` was not found",
- 'status': 404,
- 'title': 'Connection not found',
- 'type': EXCEPTIONS_LINK_MAP[404],
- },
- )
+ assert response.json == {
+ 'detail': "The Connection with connection_id: `test-connection` was not found",
+ 'status': 404,
+ 'title': 'Connection not found',
+ 'type': EXCEPTIONS_LINK_MAP[404],
+ }
def test_should_raises_401_unauthenticated(self):
response = self.client.delete("/api/v1/connections/test-connection")
@@ -128,32 +125,26 @@ def test_should_respond_200(self, session):
"/api/v1/connections/test-connection-id", environ_overrides={'REMOTE_USER': "test"}
)
assert response.status_code == 200
- self.assertEqual(
- response.json,
- {
- "connection_id": "test-connection-id",
- "conn_type": 'mysql',
- "host": 'mysql',
- "login": 'login',
- 'schema': 'testschema',
- 'port': 80,
- },
- )
+ assert response.json == {
+ "connection_id": "test-connection-id",
+ "conn_type": 'mysql',
+ "host": 'mysql',
+ "login": 'login',
+ 'schema': 'testschema',
+ 'port': 80,
+ }
def test_should_respond_404(self):
response = self.client.get(
"/api/v1/connections/invalid-connection", environ_overrides={'REMOTE_USER': "test"}
)
assert response.status_code == 404
- self.assertEqual(
- {
- 'detail': "The Connection with connection_id: `invalid-connection` was not found",
- 'status': 404,
- 'title': 'Connection not found',
- 'type': EXCEPTIONS_LINK_MAP[404],
- },
- response.json,
- )
+ assert {
+ 'detail': "The Connection with connection_id: `invalid-connection` was not found",
+ 'status': 404,
+ 'title': 'Connection not found',
+ 'type': EXCEPTIONS_LINK_MAP[404],
+ } == response.json
def test_should_raises_401_unauthenticated(self):
response = self.client.get("/api/v1/connections/test-connection-id")
@@ -173,30 +164,27 @@ def test_should_respond_200(self, session):
assert len(result) == 2
response = self.client.get("/api/v1/connections", environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == 200
- self.assertEqual(
- response.json,
- {
- 'connections': [
- {
- "connection_id": "test-connection-id-1",
- "conn_type": 'test_type',
- "host": None,
- "login": None,
- 'schema': None,
- 'port': None,
- },
- {
- "connection_id": "test-connection-id-2",
- "conn_type": 'test_type',
- "host": None,
- "login": None,
- 'schema': None,
- 'port': None,
- },
- ],
- 'total_entries': 2,
- },
- )
+ assert response.json == {
+ 'connections': [
+ {
+ "connection_id": "test-connection-id-1",
+ "conn_type": 'test_type',
+ "host": None,
+ "login": None,
+ 'schema': None,
+ 'port': None,
+ },
+ {
+ "connection_id": "test-connection-id-2",
+ "conn_type": 'test_type',
+ "host": None,
+ "login": None,
+ 'schema': None,
+ 'port': None,
+ },
+ ],
+ 'total_entries': 2,
+ }
def test_should_raises_401_unauthenticated(self):
response = self.client.get("/api/v1/connections")
@@ -249,9 +237,9 @@ def test_handle_limit_offset(self, url, expected_conn_ids, session):
session.commit()
response = self.client.get(url, environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == 200
- self.assertEqual(response.json["total_entries"], 10)
+ assert response.json["total_entries"] == 10
conn_ids = [conn["connection_id"] for conn in response.json["connections"] if conn]
- self.assertEqual(conn_ids, expected_conn_ids)
+ assert conn_ids == expected_conn_ids
@provide_session
def test_should_respect_page_size_limit_default(self, session):
@@ -262,8 +250,8 @@ def test_should_respect_page_size_limit_default(self, session):
response = self.client.get("/api/v1/connections", environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == 200
- self.assertEqual(response.json["total_entries"], 200)
- self.assertEqual(len(response.json["connections"]), 100)
+ assert response.json["total_entries"] == 200
+ assert len(response.json["connections"]) == 100
@provide_session
def test_limit_of_zero_should_return_default(self, session):
@@ -274,8 +262,8 @@ def test_limit_of_zero_should_return_default(self, session):
response = self.client.get("/api/v1/connections?limit=0", environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == 200
- self.assertEqual(response.json["total_entries"], 200)
- self.assertEqual(len(response.json["connections"]), 100)
+ assert response.json["total_entries"] == 200
+ assert len(response.json["connections"]) == 100
@provide_session
@conf_vars({("api", "maximum_page_limit"): "150"})
@@ -286,7 +274,7 @@ def test_should_return_conf_max_if_req_max_above_conf(self, session):
response = self.client.get("/api/v1/connections?limit=180", environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == 200
- self.assertEqual(len(response.json['connections']), 150)
+ assert len(response.json['connections']) == 150
def _create_connections(self, count):
return [
@@ -329,19 +317,16 @@ def test_patch_should_respond_200_with_update_mask(self, session):
)
assert response.status_code == 200
connection = session.query(Connection).filter_by(conn_id=test_connection).first()
- self.assertEqual(connection.password, None)
- self.assertEqual(
- response.json,
- {
- "connection_id": test_connection, # not updated
- "conn_type": 'test_type', # Not updated
- "extra": None, # Not updated
- 'login': "login", # updated
- "port": 80, # updated
- "schema": None,
- "host": None,
- },
- )
+ assert connection.password is None
+ assert response.json == {
+ "connection_id": test_connection, # not updated
+ "conn_type": 'test_type', # Not updated
+ "extra": None, # Not updated
+ 'login': "login", # updated
+ "port": 80, # updated
+ "schema": None,
+ "host": None,
+ }
@parameterized.expand(
[
@@ -400,7 +385,7 @@ def test_patch_should_respond_400_for_invalid_fields_in_update_mask(
environ_overrides={'REMOTE_USER': "test"},
)
assert response.status_code == 400
- self.assertEqual(response.json['detail'], error_message)
+ assert response.json['detail'] == error_message
@parameterized.expand(
[
@@ -438,7 +423,7 @@ def test_patch_should_respond_400_for_invalid_update(self, payload, error_messag
"/api/v1/connections/test-connection-id", json=payload, environ_overrides={'REMOTE_USER': "test"}
)
assert response.status_code == 400
- self.assertIn(error_message, response.json['detail'])
+ assert error_message in response.json['detail']
def test_patch_should_respond_404_not_found(self):
payload = {"connection_id": "test-connection-id", "conn_type": "test-type", "port": 90}
@@ -446,15 +431,12 @@ def test_patch_should_respond_404_not_found(self):
"/api/v1/connections/test-connection-id", json=payload, environ_overrides={'REMOTE_USER': "test"}
)
assert response.status_code == 404
- self.assertEqual(
- {
- 'detail': "The Connection with connection_id: `test-connection-id` was not found",
- 'status': 404,
- 'title': 'Connection not found',
- 'type': EXCEPTIONS_LINK_MAP[404],
- },
- response.json,
- )
+ assert {
+ 'detail': "The Connection with connection_id: `test-connection-id` was not found",
+ 'status': 404,
+ 'title': 'Connection not found',
+ 'type': EXCEPTIONS_LINK_MAP[404],
+ } == response.json
@provide_session
def test_should_raises_401_unauthenticated(self, session):
@@ -478,7 +460,7 @@ def test_post_should_respond_200(self, session):
assert response.status_code == 200
connection = session.query(Connection).all()
assert len(connection) == 1
- self.assertEqual(connection[0].conn_id, 'test-connection-id')
+ assert connection[0].conn_id == 'test-connection-id'
def test_post_should_respond_400_for_invalid_payload(self):
payload = {
@@ -488,15 +470,12 @@ def test_post_should_respond_400_for_invalid_payload(self):
"/api/v1/connections", json=payload, environ_overrides={'REMOTE_USER': "test"}
)
assert response.status_code == 400
- self.assertEqual(
- response.json,
- {
- 'detail': "{'conn_type': ['Missing data for required field.']}",
- 'status': 400,
- 'title': 'Bad Request',
- 'type': EXCEPTIONS_LINK_MAP[400],
- },
- )
+ assert response.json == {
+ 'detail': "{'conn_type': ['Missing data for required field.']}",
+ 'status': 400,
+ 'title': 'Bad Request',
+ 'type': EXCEPTIONS_LINK_MAP[400],
+ }
def test_post_should_respond_409_already_exist(self):
payload = {"connection_id": "test-connection-id", "conn_type": 'test_type'}
@@ -509,15 +488,12 @@ def test_post_should_respond_409_already_exist(self):
"/api/v1/connections", json=payload, environ_overrides={'REMOTE_USER': "test"}
)
assert response.status_code == 409
- self.assertEqual(
- response.json,
- {
- 'detail': 'Connection already exist. ID: test-connection-id',
- 'status': 409,
- 'title': 'Conflict',
- 'type': EXCEPTIONS_LINK_MAP[409],
- },
- )
+ assert response.json == {
+ 'detail': 'Connection already exist. ID: test-connection-id',
+ 'status': 409,
+ 'title': 'Conflict',
+ 'type': EXCEPTIONS_LINK_MAP[409],
+ }
def test_should_raises_401_unauthenticated(self):
response = self.client.post(
diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py b/tests/api_connexion/endpoints/test_dag_endpoint.py
index 614e8c25c9685..cf75435336d57 100644
--- a/tests/api_connexion/endpoints/test_dag_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dag_endpoint.py
@@ -116,21 +116,18 @@ def test_should_respond_200(self):
self._create_dag_models(1)
response = self.client.get("/api/v1/dags/TEST_DAG_1", environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == 200
- self.assertEqual(
- {
- "dag_id": "TEST_DAG_1",
- "description": None,
- "fileloc": "/tmp/dag_1.py",
- "file_token": 'Ii90bXAvZGFnXzEucHki.EnmIdPaUPo26lHQClbWMbDFD1Pk',
- "is_paused": False,
- "is_subdag": False,
- "owners": [],
- "root_dag_id": None,
- "schedule_interval": {"__type": "CronExpression", "value": "2 2 * * *"},
- "tags": [],
- },
- response.json,
- )
+ assert {
+ "dag_id": "TEST_DAG_1",
+ "description": None,
+ "fileloc": "/tmp/dag_1.py",
+ "file_token": 'Ii90bXAvZGFnXzEucHki.EnmIdPaUPo26lHQClbWMbDFD1Pk',
+ "is_paused": False,
+ "is_subdag": False,
+ "owners": [],
+ "root_dag_id": None,
+ "schedule_interval": {"__type": "CronExpression", "value": "2 2 * * *"},
+ "tags": [],
+ } == response.json
@conf_vars({("webserver", "secret_key"): "mysecret"})
@provide_session
@@ -144,21 +141,18 @@ def test_should_respond_200_with_schedule_interval_none(self, session=None):
session.commit()
response = self.client.get("/api/v1/dags/TEST_DAG_1", environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == 200
- self.assertEqual(
- {
- "dag_id": "TEST_DAG_1",
- "description": None,
- "fileloc": "/tmp/dag_1.py",
- "file_token": 'Ii90bXAvZGFnXzEucHki.EnmIdPaUPo26lHQClbWMbDFD1Pk',
- "is_paused": False,
- "is_subdag": False,
- "owners": [],
- "root_dag_id": None,
- "schedule_interval": None,
- "tags": [],
- },
- response.json,
- )
+ assert {
+ "dag_id": "TEST_DAG_1",
+ "description": None,
+ "fileloc": "/tmp/dag_1.py",
+ "file_token": 'Ii90bXAvZGFnXzEucHki.EnmIdPaUPo26lHQClbWMbDFD1Pk',
+ "is_paused": False,
+ "is_subdag": False,
+ "owners": [],
+ "root_dag_id": None,
+ "schedule_interval": None,
+ "tags": [],
+ } == response.json
def test_should_respond_200_with_granular_dag_access(self):
self._create_dag_models(1)
@@ -331,15 +325,12 @@ def test_should_raise_404_when_dag_is_not_found(self):
"/api/v1/dags/non_existing_dag_id/details", environ_overrides={'REMOTE_USER': "test"}
)
assert response.status_code == 404
- self.assertEqual(
- response.json,
- {
- 'detail': 'The DAG with dag_id: non_existing_dag_id was not found',
- 'status': 404,
- 'title': 'DAG not found',
- 'type': EXCEPTIONS_LINK_MAP[404],
- },
- )
+ assert response.json == {
+ 'detail': 'The DAG with dag_id: non_existing_dag_id was not found',
+ 'status': 404,
+ 'title': 'DAG not found',
+ 'type': EXCEPTIONS_LINK_MAP[404],
+ }
class TestGetDags(TestDagEndpoint):
@@ -350,44 +341,41 @@ def test_should_respond_200(self):
file_token = SERIALIZER.dumps("/tmp/dag_1.py")
file_token2 = SERIALIZER.dumps("/tmp/dag_2.py")
assert response.status_code == 200
- self.assertEqual(
- {
- "dags": [
- {
- "dag_id": "TEST_DAG_1",
- "description": None,
- "fileloc": "/tmp/dag_1.py",
- "file_token": file_token,
- "is_paused": False,
- "is_subdag": False,
- "owners": [],
- "root_dag_id": None,
- "schedule_interval": {
- "__type": "CronExpression",
- "value": "2 2 * * *",
- },
- "tags": [],
+ assert {
+ "dags": [
+ {
+ "dag_id": "TEST_DAG_1",
+ "description": None,
+ "fileloc": "/tmp/dag_1.py",
+ "file_token": file_token,
+ "is_paused": False,
+ "is_subdag": False,
+ "owners": [],
+ "root_dag_id": None,
+ "schedule_interval": {
+ "__type": "CronExpression",
+ "value": "2 2 * * *",
},
- {
- "dag_id": "TEST_DAG_2",
- "description": None,
- "fileloc": "/tmp/dag_2.py",
- "file_token": file_token2,
- "is_paused": False,
- "is_subdag": False,
- "owners": [],
- "root_dag_id": None,
- "schedule_interval": {
- "__type": "CronExpression",
- "value": "2 2 * * *",
- },
- "tags": [],
+ "tags": [],
+ },
+ {
+ "dag_id": "TEST_DAG_2",
+ "description": None,
+ "fileloc": "/tmp/dag_2.py",
+ "file_token": file_token2,
+ "is_paused": False,
+ "is_subdag": False,
+ "owners": [],
+ "root_dag_id": None,
+ "schedule_interval": {
+ "__type": "CronExpression",
+ "value": "2 2 * * *",
},
- ],
- "total_entries": 2,
- },
- response.json,
- )
+ "tags": [],
+ },
+ ],
+ "total_entries": 2,
+ } == response.json
def test_should_respond_200_with_granular_dag_access(self):
self._create_dag_models(3)
@@ -435,8 +423,8 @@ def test_should_respond_200_and_handle_pagination(self, url, expected_dag_ids):
dag_ids = [dag["dag_id"] for dag in response.json["dags"]]
- self.assertEqual(expected_dag_ids, dag_ids)
- self.assertEqual(10, response.json["total_entries"])
+ assert expected_dag_ids == dag_ids
+ assert 10 == response.json["total_entries"]
def test_should_respond_200_default_limit(self):
self._create_dag_models(101)
@@ -445,8 +433,8 @@ def test_should_respond_200_default_limit(self):
assert response.status_code == 200
- self.assertEqual(100, len(response.json["dags"]))
- self.assertEqual(101, response.json["total_entries"])
+ assert 100 == len(response.json["dags"])
+ assert 101 == response.json["total_entries"]
def test_should_raises_401_unauthenticated(self):
response = self.client.get("api/v1/dags")
@@ -474,7 +462,7 @@ def test_should_respond_200_on_patch_is_paused(self):
},
environ_overrides={'REMOTE_USER': "test"},
)
- self.assertEqual(response.status_code, 200)
+ assert response.status_code == 200
expected_response = {
"dag_id": "TEST_DAG_1",
@@ -491,7 +479,7 @@ def test_should_respond_200_on_patch_is_paused(self):
},
"tags": [],
}
- self.assertEqual(response.json, expected_response)
+ assert response.json == expected_response
def test_should_respond_200_on_patch_with_granular_dag_access(self):
self._create_dag_models(1)
@@ -514,20 +502,17 @@ def test_should_respond_400_on_invalid_request(self):
}
dag_model = self._create_dag_model()
response = self.client.patch(f"/api/v1/dags/{dag_model.dag_id}", json=patch_body)
- self.assertEqual(response.status_code, 400)
- self.assertEqual(
- response.json,
- {
- 'detail': "Property is read-only - 'schedule_interval'",
- 'status': 400,
- 'title': 'Bad Request',
- 'type': EXCEPTIONS_LINK_MAP[400],
- },
- )
+ assert response.status_code == 400
+ assert response.json == {
+ 'detail': "Property is read-only - 'schedule_interval'",
+ 'status': 400,
+ 'title': 'Bad Request',
+ 'type': EXCEPTIONS_LINK_MAP[400],
+ }
def test_should_respond_404(self):
response = self.client.get("/api/v1/dags/INVALID_DAG", environ_overrides={'REMOTE_USER': "test"})
- self.assertEqual(response.status_code, 404)
+ assert response.status_code == 404
@provide_session
def _create_dag_model(self, session=None):
@@ -559,7 +544,7 @@ def test_should_respond_200_with_update_mask(self):
environ_overrides={'REMOTE_USER': "test"},
)
- self.assertEqual(response.status_code, 200)
+ assert response.status_code == 200
expected_response = {
"dag_id": "TEST_DAG_1",
"description": None,
@@ -575,7 +560,7 @@ def test_should_respond_200_with_update_mask(self):
},
"tags": [],
}
- self.assertEqual(response.json, expected_response)
+ assert response.json == expected_response
@parameterized.expand(
[
@@ -603,8 +588,8 @@ def test_should_respond_400_for_invalid_fields_in_update_mask(self, payload, upd
json=payload,
environ_overrides={'REMOTE_USER': "test"},
)
- self.assertEqual(response.status_code, 400)
- self.assertEqual(response.json['detail'], error_message)
+ assert response.status_code == 400
+ assert response.json['detail'] == error_message
def test_should_respond_403_unauthorized(self):
dag_model = self._create_dag_model()
diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py
index 794b7f256f4fc..48960fa2365f7 100644
--- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py
@@ -129,27 +129,24 @@ def test_should_respond_204(self, session):
response = self.client.delete(
"api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1", environ_overrides={'REMOTE_USER': "test"}
)
- self.assertEqual(response.status_code, 204)
+ assert response.status_code == 204
# Check if the Dag Run is deleted from the database
response = self.client.get(
"api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1", environ_overrides={'REMOTE_USER': "test"}
)
- self.assertEqual(response.status_code, 404)
+ assert response.status_code == 404
def test_should_respond_404(self):
response = self.client.delete(
"api/v1/dags/INVALID_DAG_RUN/dagRuns/INVALID_DAG_RUN", environ_overrides={'REMOTE_USER': "test"}
)
- self.assertEqual(response.status_code, 404)
- self.assertEqual(
- response.json,
- {
- "detail": "DAGRun with DAG ID: 'INVALID_DAG_RUN' and DagRun ID: 'INVALID_DAG_RUN' not found",
- "status": 404,
- "title": "Not Found",
- "type": EXCEPTIONS_LINK_MAP[404],
- },
- )
+ assert response.status_code == 404
+ assert response.json == {
+ "detail": "DAGRun with DAG ID: 'INVALID_DAG_RUN' and DagRun ID: 'INVALID_DAG_RUN' not found",
+ "status": 404,
+ "title": "Not Found",
+ "type": EXCEPTIONS_LINK_MAP[404],
+ }
@provide_session
def test_should_raises_401_unauthenticated(self, session):
@@ -361,7 +358,7 @@ def test_should_return_conf_max_if_req_max_above_conf(self):
"api/v1/dags/TEST_DAG_ID/dagRuns?limit=180", environ_overrides={'REMOTE_USER': "test"}
)
assert response.status_code == 200
- self.assertEqual(len(response.json["dag_runs"]), 150)
+ assert len(response.json["dag_runs"]) == 150
def _create_dag_runs(self, count):
dag_runs = [
@@ -770,7 +767,7 @@ def test_naive_date_filters_raises_400(self, payload, expected_response):
"api/v1/dags/~/dagRuns/list", json=payload, environ_overrides={'REMOTE_USER': "test"}
)
assert response.status_code == 400
- self.assertEqual(response.json['detail'], expected_response)
+ assert response.json['detail'] == expected_response
@parameterized.expand(
[
@@ -818,20 +815,17 @@ def test_should_respond_200(self, name, request_json, session):
response = self.client.post(
"api/v1/dags/TEST_DAG_ID/dagRuns", json=request_json, environ_overrides={'REMOTE_USER': "test"}
)
- self.assertEqual(response.status_code, 200)
- self.assertEqual(
- {
- "conf": {},
- "dag_id": "TEST_DAG_ID",
- "dag_run_id": response.json["dag_run_id"],
- "end_date": None,
- "execution_date": response.json["execution_date"],
- "external_trigger": True,
- "start_date": response.json["start_date"],
- "state": "running",
- },
- response.json,
- )
+ assert response.status_code == 200
+ assert {
+ "conf": {},
+ "dag_id": "TEST_DAG_ID",
+ "dag_run_id": response.json["dag_run_id"],
+ "end_date": None,
+ "execution_date": response.json["execution_date"],
+ "external_trigger": True,
+ "start_date": response.json["start_date"],
+ "state": "running",
+ } == response.json
@parameterized.expand(
[
@@ -847,8 +841,8 @@ def test_should_response_400_for_naive_datetime_and_bad_datetime(self, data, exp
response = self.client.post(
"api/v1/dags/TEST_DAG_ID/dagRuns", json=data, environ_overrides={'REMOTE_USER': "test"}
)
- self.assertEqual(response.status_code, 400)
- self.assertEqual(response.json['detail'], expected)
+ assert response.status_code == 400
+ assert response.json['detail'] == expected
def test_response_404(self):
response = self.client.post(
@@ -856,16 +850,13 @@ def test_response_404(self):
json={"dag_run_id": "TEST_DAG_RUN", "execution_date": self.default_time},
environ_overrides={'REMOTE_USER': "test"},
)
- self.assertEqual(response.status_code, 404)
- self.assertEqual(
- {
- "detail": "DAG with dag_id: 'TEST_DAG_ID' not found",
- "status": 404,
- "title": "DAG not found",
- "type": EXCEPTIONS_LINK_MAP[404],
- },
- response.json,
- )
+ assert response.status_code == 404
+ assert {
+ "detail": "DAG with dag_id: 'TEST_DAG_ID' not found",
+ "status": 404,
+ "title": "DAG not found",
+ "type": EXCEPTIONS_LINK_MAP[404],
+ } == response.json
@parameterized.expand(
[
@@ -903,8 +894,8 @@ def test_response_400(self, name, url, request_json, expected_response, session)
session.add(dag_instance)
session.commit()
response = self.client.post(url, json=request_json, environ_overrides={'REMOTE_USER': "test"})
- self.assertEqual(response.status_code, 400, response.data)
- self.assertEqual(expected_response, response.json)
+ assert response.status_code == 400, response.data
+ assert expected_response == response.json
def test_response_409(self):
self._create_test_dag_run()
@@ -916,17 +907,14 @@ def test_response_409(self):
},
environ_overrides={'REMOTE_USER': "test"},
)
- self.assertEqual(response.status_code, 409, response.data)
- self.assertEqual(
- response.json,
- {
- "detail": "DAGRun with DAG ID: 'TEST_DAG_ID' and "
- "DAGRun ID: 'TEST_DAG_RUN_ID_1' already exists",
- "status": 409,
- "title": "Conflict",
- "type": EXCEPTIONS_LINK_MAP[409],
- },
- )
+ assert response.status_code == 409, response.data
+ assert response.json == {
+ "detail": "DAGRun with DAG ID: 'TEST_DAG_ID' and "
+ "DAGRun ID: 'TEST_DAG_RUN_ID_1' already exists",
+ "status": 409,
+ "title": "Conflict",
+ "type": EXCEPTIONS_LINK_MAP[409],
+ }
def test_should_raises_401_unauthenticated(self):
response = self.client.post(
diff --git a/tests/api_connexion/endpoints/test_dag_source_endpoint.py b/tests/api_connexion/endpoints/test_dag_source_endpoint.py
index baee62027f3d2..4ad8236634182 100644
--- a/tests/api_connexion/endpoints/test_dag_source_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dag_source_endpoint.py
@@ -91,9 +91,9 @@ def test_should_respond_200_text(self, store_dag_code):
url, headers={"Accept": "text/plain"}, environ_overrides={'REMOTE_USER': "test"}
)
- self.assertEqual(200, response.status_code)
- self.assertIn(dag_docstring, response.data.decode())
- self.assertEqual('text/plain', response.headers['Content-Type'])
+ assert 200 == response.status_code
+ assert dag_docstring in response.data.decode()
+ assert 'text/plain' == response.headers['Content-Type']
@parameterized.expand([(True,), (False,)])
def test_should_respond_200_json(self, store_dag_code):
@@ -111,9 +111,9 @@ def test_should_respond_200_json(self, store_dag_code):
url, headers={"Accept": 'application/json'}, environ_overrides={'REMOTE_USER': "test"}
)
- self.assertEqual(200, response.status_code)
- self.assertIn(dag_docstring, response.json['content'])
- self.assertEqual('application/json', response.headers['Content-Type'])
+ assert 200 == response.status_code
+ assert dag_docstring in response.json['content']
+ assert 'application/json' == response.headers['Content-Type']
@parameterized.expand([(True,), (False,)])
def test_should_respond_406(self, store_dag_code):
@@ -130,7 +130,7 @@ def test_should_respond_406(self, store_dag_code):
url, headers={"Accept": 'image/webp'}, environ_overrides={'REMOTE_USER': "test"}
)
- self.assertEqual(406, response.status_code)
+ assert 406 == response.status_code
@parameterized.expand([(True,), (False,)])
def test_should_respond_404(self, store_dag_code):
@@ -143,7 +143,7 @@ def test_should_respond_404(self, store_dag_code):
url, headers={"Accept": 'application/json'}, environ_overrides={'REMOTE_USER': "test"}
)
- self.assertEqual(404, response.status_code)
+ assert 404 == response.status_code
def test_should_raises_401_unauthenticated(self):
serializer = URLSafeSerializer(conf.get('webserver', 'SECRET_KEY'))
diff --git a/tests/api_connexion/endpoints/test_event_log_endpoint.py b/tests/api_connexion/endpoints/test_event_log_endpoint.py
index 11cf1553834a9..32020cf3ffa11 100644
--- a/tests/api_connexion/endpoints/test_event_log_endpoint.py
+++ b/tests/api_connexion/endpoints/test_event_log_endpoint.py
@@ -89,27 +89,26 @@ def test_should_respond_200(self, session):
f"/api/v1/eventLogs/{event_log_id}", environ_overrides={'REMOTE_USER': "test"}
)
assert response.status_code == 200
- self.assertEqual(
- response.json,
- {
- "event_log_id": event_log_id,
- "event": "TEST_EVENT",
- "dag_id": "TEST_DAG_ID",
- "task_id": "TEST_TASK_ID",
- "execution_date": self.default_time,
- "owner": 'airflow',
- "when": self.default_time,
- "extra": None,
- },
- )
+ assert response.json == {
+ "event_log_id": event_log_id,
+ "event": "TEST_EVENT",
+ "dag_id": "TEST_DAG_ID",
+ "task_id": "TEST_TASK_ID",
+ "execution_date": self.default_time,
+ "owner": 'airflow',
+ "when": self.default_time,
+ "extra": None,
+ }
def test_should_respond_404(self):
response = self.client.get("/api/v1/eventLogs/1", environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == 404
- self.assertEqual(
- {'detail': None, 'status': 404, 'title': 'Event Log not found', 'type': EXCEPTIONS_LINK_MAP[404]},
- response.json,
- )
+ assert {
+ 'detail': None,
+ 'status': 404,
+ 'title': 'Event Log not found',
+ 'type': EXCEPTIONS_LINK_MAP[404],
+ } == response.json
@provide_session
def test_should_raises_401_unauthenticated(self, session):
@@ -152,44 +151,41 @@ def test_should_respond_200(self, session):
session.commit()
response = self.client.get("/api/v1/eventLogs", environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == 200
- self.assertEqual(
- response.json,
- {
- "event_logs": [
- {
- "event_log_id": log_model_1.id,
- "event": "TEST_EVENT_1",
- "dag_id": "TEST_DAG_ID",
- "task_id": "TEST_TASK_ID",
- "execution_date": self.default_time,
- "owner": 'airflow',
- "when": self.default_time,
- "extra": None,
- },
- {
- "event_log_id": log_model_2.id,
- "event": "TEST_EVENT_2",
- "dag_id": "TEST_DAG_ID",
- "task_id": "TEST_TASK_ID",
- "execution_date": self.default_time,
- "owner": 'airflow',
- "when": self.default_time_2,
- "extra": None,
- },
- {
- "event_log_id": log_model_3.id,
- "event": "cli_scheduler",
- "dag_id": None,
- "task_id": None,
- "execution_date": None,
- "owner": 'root',
- "when": self.default_time_2,
- "extra": '{"host_name": "e24b454f002a"}',
- },
- ],
- "total_entries": 3,
- },
- )
+ assert response.json == {
+ "event_logs": [
+ {
+ "event_log_id": log_model_1.id,
+ "event": "TEST_EVENT_1",
+ "dag_id": "TEST_DAG_ID",
+ "task_id": "TEST_TASK_ID",
+ "execution_date": self.default_time,
+ "owner": 'airflow',
+ "when": self.default_time,
+ "extra": None,
+ },
+ {
+ "event_log_id": log_model_2.id,
+ "event": "TEST_EVENT_2",
+ "dag_id": "TEST_DAG_ID",
+ "task_id": "TEST_TASK_ID",
+ "execution_date": self.default_time,
+ "owner": 'airflow',
+ "when": self.default_time_2,
+ "extra": None,
+ },
+ {
+ "event_log_id": log_model_3.id,
+ "event": "cli_scheduler",
+ "dag_id": None,
+ "task_id": None,
+ "execution_date": None,
+ "owner": 'root',
+ "when": self.default_time_2,
+ "extra": '{"host_name": "e24b454f002a"}',
+ },
+ ],
+ "total_entries": 3,
+ }
@provide_session
def test_should_raises_401_unauthenticated(self, session):
@@ -258,9 +254,9 @@ def test_handle_limit_and_offset(self, url, expected_events, session):
response = self.client.get(url, environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == 200
- self.assertEqual(response.json["total_entries"], 10)
+ assert response.json["total_entries"] == 10
events = [event_log["event"] for event_log in response.json["event_logs"]]
- self.assertEqual(events, expected_events)
+ assert events == expected_events
@provide_session
def test_should_respect_page_size_limit_default(self, session):
@@ -271,8 +267,8 @@ def test_should_respect_page_size_limit_default(self, session):
response = self.client.get("/api/v1/eventLogs", environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == 200
- self.assertEqual(response.json["total_entries"], 200)
- self.assertEqual(len(response.json["event_logs"]), 100) # default 100
+ assert response.json["total_entries"] == 200
+ assert len(response.json["event_logs"]) == 100 # default 100
@provide_session
@conf_vars({("api", "maximum_page_limit"): "150"})
@@ -283,7 +279,7 @@ def test_should_return_conf_max_if_req_max_above_conf(self, session):
response = self.client.get("/api/v1/eventLogs?limit=180", environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == 200
- self.assertEqual(len(response.json['event_logs']), 150)
+ assert len(response.json['event_logs']) == 150
def _create_event_logs(self, count):
return [
diff --git a/tests/api_connexion/endpoints/test_extra_link_endpoint.py b/tests/api_connexion/endpoints/test_extra_link_endpoint.py
index d67e3cc2e5fe4..3864d1f7c7df8 100644
--- a/tests/api_connexion/endpoints/test_extra_link_endpoint.py
+++ b/tests/api_connexion/endpoints/test_extra_link_endpoint.py
@@ -129,16 +129,13 @@ def test_should_respond_404(self, name, url, expected_title, expected_detail):
del name
response = self.client.get(url, environ_overrides={'REMOTE_USER': "test"})
- self.assertEqual(404, response.status_code)
- self.assertEqual(
- {
- "detail": expected_detail,
- "status": 404,
- "title": expected_title,
- "type": EXCEPTIONS_LINK_MAP[404],
- },
- response.json,
- )
+ assert 404 == response.status_code
+ assert {
+ "detail": expected_detail,
+ "status": 404,
+ "title": expected_title,
+ "type": EXCEPTIONS_LINK_MAP[404],
+ } == response.json
def test_should_raise_403_forbidden(self):
response = self.client.get(
@@ -161,10 +158,10 @@ def test_should_respond_200(self):
environ_overrides={'REMOTE_USER': "test"},
)
- self.assertEqual(200, response.status_code, response.data)
- self.assertEqual(
- {"BigQuery Console": "https://console.cloud.google.com/bigquery?j=TEST_JOB_ID"}, response.json
- )
+ assert 200 == response.status_code, response.data
+ assert {
+ "BigQuery Console": "https://console.cloud.google.com/bigquery?j=TEST_JOB_ID"
+ } == response.json
@mock_plugin_manager(plugins=[])
def test_should_respond_200_missing_xcom(self):
@@ -173,11 +170,8 @@ def test_should_respond_200_missing_xcom(self):
environ_overrides={'REMOTE_USER': "test"},
)
- self.assertEqual(200, response.status_code, response.data)
- self.assertEqual(
- {"BigQuery Console": None},
- response.json,
- )
+ assert 200 == response.status_code, response.data
+ assert {"BigQuery Console": None} == response.json
@mock_plugin_manager(plugins=[])
def test_should_respond_200_multiple_links(self):
@@ -193,14 +187,11 @@ def test_should_respond_200_multiple_links(self):
environ_overrides={'REMOTE_USER': "test"},
)
- self.assertEqual(200, response.status_code, response.data)
- self.assertEqual(
- {
- "BigQuery Console #1": "https://console.cloud.google.com/bigquery?j=TEST_JOB_ID_1",
- "BigQuery Console #2": "https://console.cloud.google.com/bigquery?j=TEST_JOB_ID_2",
- },
- response.json,
- )
+ assert 200 == response.status_code, response.data
+ assert {
+ "BigQuery Console #1": "https://console.cloud.google.com/bigquery?j=TEST_JOB_ID_1",
+ "BigQuery Console #2": "https://console.cloud.google.com/bigquery?j=TEST_JOB_ID_2",
+ } == response.json
@mock_plugin_manager(plugins=[])
def test_should_respond_200_multiple_links_missing_xcom(self):
@@ -209,11 +200,8 @@ def test_should_respond_200_multiple_links_missing_xcom(self):
environ_overrides={'REMOTE_USER': "test"},
)
- self.assertEqual(200, response.status_code, response.data)
- self.assertEqual(
- {"BigQuery Console #1": None, "BigQuery Console #2": None},
- response.json,
- )
+ assert 200 == response.status_code, response.data
+ assert {"BigQuery Console #1": None, "BigQuery Console #2": None} == response.json
def test_should_respond_200_support_plugins(self):
class GoogleLink(BaseOperatorLink):
@@ -248,15 +236,12 @@ class AirflowTestPlugin(AirflowPlugin):
environ_overrides={'REMOTE_USER': "test"},
)
- self.assertEqual(200, response.status_code, response.data)
- self.assertEqual(
- {
- "BigQuery Console": None,
- "Google": "https://www.google.com",
- "S3": (
- "https://s3.amazonaws.com/airflow-logs/"
- "TEST_DAG_ID/TEST_SINGLE_QUERY/2020-01-01T00%3A00%3A00%2B00%3A00"
- ),
- },
- response.json,
- )
+ assert 200 == response.status_code, response.data
+ assert {
+ "BigQuery Console": None,
+ "Google": "https://www.google.com",
+ "S3": (
+ "https://s3.amazonaws.com/airflow-logs/"
+ "TEST_DAG_ID/TEST_SINGLE_QUERY/2020-01-01T00%3A00%3A00%2B00%3A00"
+ ),
+ } == response.json
diff --git a/tests/api_connexion/endpoints/test_health_endpoint.py b/tests/api_connexion/endpoints/test_health_endpoint.py
index 9bc60065d79a4..defb97b5eac63 100644
--- a/tests/api_connexion/endpoints/test_health_endpoint.py
+++ b/tests/api_connexion/endpoints/test_health_endpoint.py
@@ -58,11 +58,11 @@ def test_healthy_scheduler_status(self, session):
)
session.commit()
resp_json = self.client.get("/api/v1/health").json
- self.assertEqual("healthy", resp_json["metadatabase"]["status"])
- self.assertEqual("healthy", resp_json["scheduler"]["status"])
- self.assertEqual(
- last_scheduler_heartbeat_for_testing_1.isoformat(),
- resp_json["scheduler"]["latest_scheduler_heartbeat"],
+ assert "healthy" == resp_json["metadatabase"]["status"]
+ assert "healthy" == resp_json["scheduler"]["status"]
+ assert (
+ last_scheduler_heartbeat_for_testing_1.isoformat()
+ == resp_json["scheduler"]["latest_scheduler_heartbeat"]
)
@provide_session
@@ -77,22 +77,22 @@ def test_unhealthy_scheduler_is_slow(self, session):
)
session.commit()
resp_json = self.client.get("/api/v1/health").json
- self.assertEqual("healthy", resp_json["metadatabase"]["status"])
- self.assertEqual("unhealthy", resp_json["scheduler"]["status"])
- self.assertEqual(
- last_scheduler_heartbeat_for_testing_2.isoformat(),
- resp_json["scheduler"]["latest_scheduler_heartbeat"],
+ assert "healthy" == resp_json["metadatabase"]["status"]
+ assert "unhealthy" == resp_json["scheduler"]["status"]
+ assert (
+ last_scheduler_heartbeat_for_testing_2.isoformat()
+ == resp_json["scheduler"]["latest_scheduler_heartbeat"]
)
def test_unhealthy_scheduler_no_job(self):
resp_json = self.client.get("/api/v1/health").json
- self.assertEqual("healthy", resp_json["metadatabase"]["status"])
- self.assertEqual("unhealthy", resp_json["scheduler"]["status"])
- self.assertIsNone(resp_json["scheduler"]["latest_scheduler_heartbeat"])
+ assert "healthy" == resp_json["metadatabase"]["status"]
+ assert "unhealthy" == resp_json["scheduler"]["status"]
+ assert resp_json["scheduler"]["latest_scheduler_heartbeat"] is None
@mock.patch("airflow.api_connexion.endpoints.health_endpoint.SchedulerJob.most_recent_job")
def test_unhealthy_metadatabase_status(self, mock_scheduler_most_recent_job):
mock_scheduler_most_recent_job.side_effect = Exception
resp_json = self.client.get("/api/v1/health").json
- self.assertEqual("unhealthy", resp_json["metadatabase"]["status"])
- self.assertIsNone(resp_json["scheduler"]["latest_scheduler_heartbeat"])
+ assert "unhealthy" == resp_json["metadatabase"]["status"]
+ assert resp_json["scheduler"]["latest_scheduler_heartbeat"] is None
diff --git a/tests/api_connexion/endpoints/test_import_error_endpoint.py b/tests/api_connexion/endpoints/test_import_error_endpoint.py
index 66eef6098222c..c8c144b1c9175 100644
--- a/tests/api_connexion/endpoints/test_import_error_endpoint.py
+++ b/tests/api_connexion/endpoints/test_import_error_endpoint.py
@@ -81,28 +81,22 @@ def test_response_200(self, session):
assert response.status_code == 200
response_data = response.json
response_data["import_error_id"] = 1
- self.assertEqual(
- {
- "filename": "Lorem_ipsum.py",
- "import_error_id": 1,
- "stack_trace": "Lorem ipsum",
- "timestamp": "2020-06-10T12:00:00+00:00",
- },
- response_data,
- )
+ assert {
+ "filename": "Lorem_ipsum.py",
+ "import_error_id": 1,
+ "stack_trace": "Lorem ipsum",
+ "timestamp": "2020-06-10T12:00:00+00:00",
+ } == response_data
def test_response_404(self):
response = self.client.get("/api/v1/importErrors/2", environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == 404
- self.assertEqual(
- {
- "detail": "The ImportError with import_error_id: `2` was not found",
- "status": 404,
- "title": "Import error not found",
- "type": EXCEPTIONS_LINK_MAP[404],
- },
- response.json,
- )
+ assert {
+ "detail": "The ImportError with import_error_id: `2` was not found",
+ "status": 404,
+ "title": "Import error not found",
+ "type": EXCEPTIONS_LINK_MAP[404],
+ } == response.json
@provide_session
def test_should_raises_401_unauthenticated(self, session):
@@ -144,26 +138,23 @@ def test_get_import_errors(self, session):
assert response.status_code == 200
response_data = response.json
self._normalize_import_errors(response_data['import_errors'])
- self.assertEqual(
- {
- "import_errors": [
- {
- "filename": "Lorem_ipsum.py",
- "import_error_id": 1,
- "stack_trace": "Lorem ipsum",
- "timestamp": "2020-06-10T12:00:00+00:00",
- },
- {
- "filename": "Lorem_ipsum.py",
- "import_error_id": 2,
- "stack_trace": "Lorem ipsum",
- "timestamp": "2020-06-10T12:00:00+00:00",
- },
- ],
- "total_entries": 2,
- },
- response_data,
- )
+ assert {
+ "import_errors": [
+ {
+ "filename": "Lorem_ipsum.py",
+ "import_error_id": 1,
+ "stack_trace": "Lorem ipsum",
+ "timestamp": "2020-06-10T12:00:00+00:00",
+ },
+ {
+ "filename": "Lorem_ipsum.py",
+ "import_error_id": 2,
+ "stack_trace": "Lorem ipsum",
+ "timestamp": "2020-06-10T12:00:00+00:00",
+ },
+ ],
+ "total_entries": 2,
+ } == response_data
@provide_session
def test_should_raises_401_unauthenticated(self, session):
@@ -213,7 +204,7 @@ def test_limit_and_offset(self, url, expected_import_error_ids, session):
assert response.status_code == 200
import_ids = [pool["filename"] for pool in response.json["import_errors"]]
- self.assertEqual(import_ids, expected_import_error_ids)
+ assert import_ids == expected_import_error_ids
@provide_session
def test_should_respect_page_size_limit_default(self, session):
@@ -229,7 +220,7 @@ def test_should_respect_page_size_limit_default(self, session):
session.commit()
response = self.client.get("/api/v1/importErrors", environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == 200
- self.assertEqual(len(response.json['import_errors']), 100)
+ assert len(response.json['import_errors']) == 100
@provide_session
@conf_vars({("api", "maximum_page_limit"): "150"})
@@ -248,4 +239,4 @@ def test_should_return_conf_max_if_req_max_above_conf(self, session):
"/api/v1/importErrors?limit=180", environ_overrides={'REMOTE_USER': "test"}
)
assert response.status_code == 200
- self.assertEqual(len(response.json['import_errors']), 150)
+ assert len(response.json['import_errors']) == 150
diff --git a/tests/api_connexion/endpoints/test_log_endpoint.py b/tests/api_connexion/endpoints/test_log_endpoint.py
index 094fc5f47c5f3..509cbe7c0b6d2 100644
--- a/tests/api_connexion/endpoints/test_log_endpoint.py
+++ b/tests/api_connexion/endpoints/test_log_endpoint.py
@@ -167,13 +167,13 @@ def test_should_respond_200_json(self, session):
expected_filename = "{}/{}/{}/{}/1.log".format(
self.log_dir, self.DAG_ID, self.TASK_ID, self.default_time.replace(":", ".")
)
- self.assertEqual(
- response.json['content'],
- f"[('', '*** Reading local file: {expected_filename}\\nLog for testing.')]",
+ assert (
+ response.json['content']
+ == f"[('', '*** Reading local file: {expected_filename}\\nLog for testing.')]"
)
info = serializer.loads(response.json['continuation_token'])
- self.assertEqual(info, {'end_of_log': True})
- self.assertEqual(200, response.status_code)
+ assert info == {'end_of_log': True}
+ assert 200 == response.status_code
@provide_session
def test_should_respond_200_text_plain(self, session):
@@ -191,10 +191,10 @@ def test_should_respond_200_text_plain(self, session):
expected_filename = "{}/{}/{}/{}/1.log".format(
self.log_dir, self.DAG_ID, self.TASK_ID, self.default_time.replace(':', '.')
)
- self.assertEqual(200, response.status_code)
- self.assertEqual(
- response.data.decode('utf-8'),
- f"\n*** Reading local file: {expected_filename}\nLog for testing.\n",
+ assert 200 == response.status_code
+ assert (
+ response.data.decode('utf-8')
+ == f"\n*** Reading local file: {expected_filename}\nLog for testing.\n"
)
@provide_session
@@ -209,8 +209,8 @@ def test_get_logs_response_with_ti_equal_to_none(self, session):
f"taskInstances/Invalid-Task-ID/logs/1?token={token}",
environ_overrides={'REMOTE_USER': "test"},
)
- self.assertEqual(response.status_code, 400)
- self.assertEqual(response.json['detail'], "Task instance did not exist in the DB")
+ assert response.status_code == 400
+ assert response.json['detail'] == "Task instance did not exist in the DB"
@provide_session
def test_get_logs_with_metadata_as_download_large_file(self, session):
@@ -229,10 +229,10 @@ def test_get_logs_with_metadata_as_download_large_file(self, session):
environ_overrides={'REMOTE_USER': "test"},
)
- self.assertIn('1st line', response.data.decode('utf-8'))
- self.assertIn('2nd line', response.data.decode('utf-8'))
- self.assertIn('3rd line', response.data.decode('utf-8'))
- self.assertNotIn('should never be read', response.data.decode('utf-8'))
+ assert '1st line' in response.data.decode('utf-8')
+ assert '2nd line' in response.data.decode('utf-8')
+ assert '3rd line' in response.data.decode('utf-8')
+ assert 'should never be read' not in response.data.decode('utf-8')
@mock.patch("airflow.api_connexion.endpoints.log_endpoint.TaskLogReader")
def test_get_logs_for_handler_without_read_method(self, mock_log_reader):
@@ -249,8 +249,8 @@ def test_get_logs_for_handler_without_read_method(self, mock_log_reader):
headers={'Content-Type': 'application/jso'},
environ_overrides={'REMOTE_USER': "test"},
)
- self.assertEqual(400, response.status_code)
- self.assertIn('Task log handler does not support read logs.', response.data.decode('utf-8'))
+ assert 400 == response.status_code
+ assert 'Task log handler does not support read logs.' in response.data.decode('utf-8')
@provide_session
def test_bad_signature_raises(self, session):
@@ -263,15 +263,12 @@ def test_bad_signature_raises(self, session):
headers={'Accept': 'application/json'},
environ_overrides={'REMOTE_USER': "test"},
)
- self.assertEqual(
- response.json,
- {
- 'detail': None,
- 'status': 400,
- 'title': "Bad Signature. Please use only the tokens provided by the API.",
- 'type': EXCEPTIONS_LINK_MAP[400],
- },
- )
+ assert response.json == {
+ 'detail': None,
+ 'status': 400,
+ 'title': "Bad Signature. Please use only the tokens provided by the API.",
+ 'type': EXCEPTIONS_LINK_MAP[400],
+ }
def test_raises_404_for_invalid_dag_run_id(self):
response = self.client.get(
@@ -280,10 +277,12 @@ def test_raises_404_for_invalid_dag_run_id(self):
headers={'Accept': 'application/json'},
environ_overrides={'REMOTE_USER': "test"},
)
- self.assertEqual(
- response.json,
- {'detail': None, 'status': 404, 'title': "DAG Run not found", 'type': EXCEPTIONS_LINK_MAP[404]},
- )
+ assert response.json == {
+ 'detail': None,
+ 'status': 404,
+ 'title': "DAG Run not found",
+ 'type': EXCEPTIONS_LINK_MAP[404],
+ }
def test_should_raises_401_unauthenticated(self):
key = self.app.config["SECRET_KEY"]
diff --git a/tests/api_connexion/endpoints/test_pool_endpoint.py b/tests/api_connexion/endpoints/test_pool_endpoint.py
index 16e86078b48e7..7bf7071da896e 100644
--- a/tests/api_connexion/endpoints/test_pool_endpoint.py
+++ b/tests/api_connexion/endpoints/test_pool_endpoint.py
@@ -72,30 +72,27 @@ def test_response_200(self, session):
assert len(result) == 2 # accounts for the default pool as well
response = self.client.get("/api/v1/pools", environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == 200
- self.assertEqual(
- {
- "pools": [
- {
- "name": "default_pool",
- "slots": 128,
- "occupied_slots": 0,
- "running_slots": 0,
- "queued_slots": 0,
- "open_slots": 128,
- },
- {
- "name": "test_pool_a",
- "slots": 3,
- "occupied_slots": 0,
- "running_slots": 0,
- "queued_slots": 0,
- "open_slots": 3,
- },
- ],
- "total_entries": 2,
- },
- response.json,
- )
+ assert {
+ "pools": [
+ {
+ "name": "default_pool",
+ "slots": 128,
+ "occupied_slots": 0,
+ "running_slots": 0,
+ "queued_slots": 0,
+ "open_slots": 128,
+ },
+ {
+ "name": "test_pool_a",
+ "slots": 3,
+ "occupied_slots": 0,
+ "running_slots": 0,
+ "queued_slots": 0,
+ "open_slots": 3,
+ },
+ ],
+ "total_entries": 2,
+ } == response.json
def test_should_raises_401_unauthenticated(self):
response = self.client.get("/api/v1/pools")
@@ -134,11 +131,11 @@ def test_limit_and_offset(self, url, expected_pool_ids, session):
session.add_all(pools)
session.commit()
result = session.query(Pool).count()
- self.assertEqual(result, 121) # accounts for default pool as well
+ assert result == 121 # accounts for default pool as well
response = self.client.get(url, environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == 200
pool_ids = [pool["name"] for pool in response.json["pools"]]
- self.assertEqual(pool_ids, expected_pool_ids)
+ assert pool_ids == expected_pool_ids
@provide_session
def test_should_respect_page_size_limit_default(self, session):
@@ -146,10 +143,10 @@ def test_should_respect_page_size_limit_default(self, session):
session.add_all(pools)
session.commit()
result = session.query(Pool).count()
- self.assertEqual(result, 121)
+ assert result == 121
response = self.client.get("/api/v1/pools", environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == 200
- self.assertEqual(len(response.json['pools']), 100)
+ assert len(response.json['pools']) == 100
@provide_session
@conf_vars({("api", "maximum_page_limit"): "150"})
@@ -158,10 +155,10 @@ def test_should_return_conf_max_if_req_max_above_conf(self, session):
session.add_all(pools)
session.commit()
result = session.query(Pool).count()
- self.assertEqual(result, 200)
+ assert result == 200
response = self.client.get("/api/v1/pools?limit=180", environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == 200
- self.assertEqual(len(response.json['pools']), 150)
+ assert len(response.json['pools']) == 150
class TestGetPool(TestBasePoolEndpoints):
@@ -172,30 +169,24 @@ def test_response_200(self, session):
session.commit()
response = self.client.get("/api/v1/pools/test_pool_a", environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == 200
- self.assertEqual(
- {
- "name": "test_pool_a",
- "slots": 3,
- "occupied_slots": 0,
- "running_slots": 0,
- "queued_slots": 0,
- "open_slots": 3,
- },
- response.json,
- )
+ assert {
+ "name": "test_pool_a",
+ "slots": 3,
+ "occupied_slots": 0,
+ "running_slots": 0,
+ "queued_slots": 0,
+ "open_slots": 3,
+ } == response.json
def test_response_404(self):
response = self.client.get("/api/v1/pools/invalid_pool", environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == 404
- self.assertEqual(
- {
- "detail": "Pool with name:'invalid_pool' not found",
- "status": 404,
- "title": "Not Found",
- "type": EXCEPTIONS_LINK_MAP[404],
- },
- response.json,
- )
+ assert {
+ "detail": "Pool with name:'invalid_pool' not found",
+ "status": 404,
+ "title": "Not Found",
+ "type": EXCEPTIONS_LINK_MAP[404],
+ } == response.json
def test_should_raises_401_unauthenticated(self):
response = self.client.get("/api/v1/pools/default_pool")
@@ -215,20 +206,17 @@ def test_response_204(self, session):
assert response.status_code == 204
# Check if the pool is deleted from the db
response = self.client.get(f"api/v1/pools/{pool_name}", environ_overrides={'REMOTE_USER': "test"})
- self.assertEqual(response.status_code, 404)
+ assert response.status_code == 404
def test_response_404(self):
response = self.client.delete("api/v1/pools/invalid_pool", environ_overrides={'REMOTE_USER': "test"})
- self.assertEqual(response.status_code, 404)
- self.assertEqual(
- {
- "detail": "Pool with name:'invalid_pool' not found",
- "status": 404,
- "title": "Not Found",
- "type": EXCEPTIONS_LINK_MAP[404],
- },
- response.json,
- )
+ assert response.status_code == 404
+ assert {
+ "detail": "Pool with name:'invalid_pool' not found",
+ "status": 404,
+ "title": "Not Found",
+ "type": EXCEPTIONS_LINK_MAP[404],
+ } == response.json
@provide_session
def test_should_raises_401_unauthenticated(self, session):
@@ -254,17 +242,14 @@ def test_response_200(self):
environ_overrides={'REMOTE_USER': "test"},
)
assert response.status_code == 200
- self.assertEqual(
- {
- "name": "test_pool_a",
- "slots": 3,
- "occupied_slots": 0,
- "running_slots": 0,
- "queued_slots": 0,
- "open_slots": 3,
- },
- response.json,
- )
+ assert {
+ "name": "test_pool_a",
+ "slots": 3,
+ "occupied_slots": 0,
+ "running_slots": 0,
+ "queued_slots": 0,
+ "open_slots": 3,
+ } == response.json
@provide_session
def test_response_409(self, session):
@@ -278,15 +263,12 @@ def test_response_409(self, session):
environ_overrides={'REMOTE_USER': "test"},
)
assert response.status_code == 409
- self.assertEqual(
- {
- "detail": f"Pool: {pool_name} already exists",
- "status": 409,
- "title": "Conflict",
- "type": EXCEPTIONS_LINK_MAP[409],
- },
- response.json,
- )
+ assert {
+ "detail": f"Pool: {pool_name} already exists",
+ "status": 409,
+ "title": "Conflict",
+ "type": EXCEPTIONS_LINK_MAP[409],
+ } == response.json
@parameterized.expand(
[
@@ -318,15 +300,12 @@ def test_response_400(self, name, request_json, error_detail):
"api/v1/pools", json=request_json, environ_overrides={'REMOTE_USER': "test"}
)
assert response.status_code == 400
- self.assertDictEqual(
- {
- "detail": error_detail,
- "status": 400,
- "title": "Bad Request",
- "type": EXCEPTIONS_LINK_MAP[400],
- },
- response.json,
- )
+ assert {
+ "detail": error_detail,
+ "status": 400,
+ "title": "Bad Request",
+ "type": EXCEPTIONS_LINK_MAP[400],
+ } == response.json
def test_should_raises_401_unauthenticated(self):
response = self.client.post("api/v1/pools", json={"name": "test_pool_a", "slots": 3})
@@ -345,18 +324,15 @@ def test_response_200(self, session):
json={"name": "test_pool_a", "slots": 3},
environ_overrides={'REMOTE_USER': "test"},
)
- self.assertEqual(response.status_code, 200)
- self.assertEqual(
- {
- "occupied_slots": 0,
- "queued_slots": 0,
- "name": "test_pool_a",
- "open_slots": 3,
- "running_slots": 0,
- "slots": 3,
- },
- response.json,
- )
+ assert response.status_code == 200
+ assert {
+ "occupied_slots": 0,
+ "queued_slots": 0,
+ "name": "test_pool_a",
+ "open_slots": 3,
+ "running_slots": 0,
+ "slots": 3,
+ } == response.json
@parameterized.expand(
[
@@ -380,15 +356,12 @@ def test_response_400(self, error_detail, request_json, session):
"api/v1/pools/test_pool", json=request_json, environ_overrides={'REMOTE_USER': "test"}
)
assert response.status_code == 400
- self.assertEqual(
- {
- "detail": error_detail,
- "status": 400,
- "title": "Bad Request",
- "type": EXCEPTIONS_LINK_MAP[400],
- },
- response.json,
- )
+ assert {
+ "detail": error_detail,
+ "status": 400,
+ "title": "Bad Request",
+ "type": EXCEPTIONS_LINK_MAP[400],
+ } == response.json
@provide_session
def test_should_raises_401_unauthenticated(self, session):
@@ -408,15 +381,12 @@ class TestModifyDefaultPool(TestBasePoolEndpoints):
def test_delete_400(self):
response = self.client.delete("api/v1/pools/default_pool", environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == 400
- self.assertEqual(
- {
- "detail": "Default Pool can't be deleted",
- "status": 400,
- "title": "Bad Request",
- "type": EXCEPTIONS_LINK_MAP[400],
- },
- response.json,
- )
+ assert {
+ "detail": "Default Pool can't be deleted",
+ "status": 400,
+ "title": "Bad Request",
+ "type": EXCEPTIONS_LINK_MAP[400],
+ } == response.json
@parameterized.expand(
[
@@ -492,7 +462,7 @@ def test_patch(self, name, status_code, url, json, expected_response):
del name
response = self.client.patch(url, json=json, environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == status_code
- self.assertEqual(response.json, expected_response)
+ assert response.json == expected_response
class TestPatchPoolWithUpdateMask(TestBasePoolEndpoints):
@@ -531,17 +501,14 @@ def test_response_200(self, url, patch_json, expected_name, expected_slots, sess
session.commit()
response = self.client.patch(url, json=patch_json, environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == 200
- self.assertEqual(
- {
- "name": expected_name,
- "slots": expected_slots,
- "occupied_slots": 0,
- "running_slots": 0,
- "queued_slots": 0,
- "open_slots": expected_slots,
- },
- response.json,
- )
+ assert {
+ "name": expected_name,
+ "slots": expected_slots,
+ "occupied_slots": 0,
+ "running_slots": 0,
+ "queued_slots": 0,
+ "open_slots": expected_slots,
+ } == response.json
@parameterized.expand(
[
@@ -579,12 +546,9 @@ def test_response_400(self, name, error_detail, url, patch_json, session):
session.commit()
response = self.client.patch(url, json=patch_json, environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == 400
- self.assertEqual(
- {
- "detail": error_detail,
- "status": 400,
- "title": "Bad Request",
- "type": EXCEPTIONS_LINK_MAP[400],
- },
- response.json,
- )
+ assert {
+ "detail": error_detail,
+ "status": 400,
+ "title": "Bad Request",
+ "type": EXCEPTIONS_LINK_MAP[400],
+ } == response.json
diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
index 2fb0dc05470b2..84c957fb1643d 100644
--- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py
+++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
@@ -142,32 +142,29 @@ def test_should_respond_200(self, session):
"/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context",
environ_overrides={"REMOTE_USER": "test"},
)
- self.assertEqual(response.status_code, 200)
- self.assertDictEqual(
- response.json,
- {
- "dag_id": "example_python_operator",
- "duration": 10000.0,
- "end_date": "2020-01-03T00:00:00+00:00",
- "execution_date": "2020-01-01T00:00:00+00:00",
- "executor_config": "{}",
- "hostname": "",
- "max_tries": 0,
- "operator": "PythonOperator",
- "pid": 100,
- "pool": "default_pool",
- "pool_slots": 1,
- "priority_weight": 6,
- "queue": "default_queue",
- "queued_when": None,
- "sla_miss": None,
- "start_date": "2020-01-02T00:00:00+00:00",
- "state": "running",
- "task_id": "print_the_context",
- "try_number": 0,
- "unixname": getpass.getuser(),
- },
- )
+ assert response.status_code == 200
+ assert response.json == {
+ "dag_id": "example_python_operator",
+ "duration": 10000.0,
+ "end_date": "2020-01-03T00:00:00+00:00",
+ "execution_date": "2020-01-01T00:00:00+00:00",
+ "executor_config": "{}",
+ "hostname": "",
+ "max_tries": 0,
+ "operator": "PythonOperator",
+ "pid": 100,
+ "pool": "default_pool",
+ "pool_slots": 1,
+ "priority_weight": 6,
+ "queue": "default_queue",
+ "queued_when": None,
+ "sla_miss": None,
+ "start_date": "2020-01-02T00:00:00+00:00",
+ "state": "running",
+ "task_id": "print_the_context",
+ "try_number": 0,
+ "unixname": getpass.getuser(),
+ }
@provide_session
def test_should_respond_200_task_instance_with_sla(self, session):
@@ -185,41 +182,38 @@ def test_should_respond_200_task_instance_with_sla(self, session):
"/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context",
environ_overrides={"REMOTE_USER": "test"},
)
- self.assertEqual(response.status_code, 200)
+ assert response.status_code == 200
- self.assertDictEqual(
- response.json,
- {
+ assert response.json == {
+ "dag_id": "example_python_operator",
+ "duration": 10000.0,
+ "end_date": "2020-01-03T00:00:00+00:00",
+ "execution_date": "2020-01-01T00:00:00+00:00",
+ "executor_config": "{}",
+ "hostname": "",
+ "max_tries": 0,
+ "operator": "PythonOperator",
+ "pid": 100,
+ "pool": "default_pool",
+ "pool_slots": 1,
+ "priority_weight": 6,
+ "queue": "default_queue",
+ "queued_when": None,
+ "sla_miss": {
"dag_id": "example_python_operator",
- "duration": 10000.0,
- "end_date": "2020-01-03T00:00:00+00:00",
+ "description": None,
+ "email_sent": False,
"execution_date": "2020-01-01T00:00:00+00:00",
- "executor_config": "{}",
- "hostname": "",
- "max_tries": 0,
- "operator": "PythonOperator",
- "pid": 100,
- "pool": "default_pool",
- "pool_slots": 1,
- "priority_weight": 6,
- "queue": "default_queue",
- "queued_when": None,
- "sla_miss": {
- "dag_id": "example_python_operator",
- "description": None,
- "email_sent": False,
- "execution_date": "2020-01-01T00:00:00+00:00",
- "notification_sent": False,
- "task_id": "print_the_context",
- "timestamp": "2020-01-01T00:00:00+00:00",
- },
- "start_date": "2020-01-02T00:00:00+00:00",
- "state": "running",
+ "notification_sent": False,
"task_id": "print_the_context",
- "try_number": 0,
- "unixname": getpass.getuser(),
+ "timestamp": "2020-01-01T00:00:00+00:00",
},
- )
+ "start_date": "2020-01-02T00:00:00+00:00",
+ "state": "running",
+ "task_id": "print_the_context",
+ "try_number": 0,
+ "unixname": getpass.getuser(),
+ }
def test_should_raises_401_unauthenticated(self):
response = self.client.get(
@@ -407,9 +401,9 @@ def test_should_respond_200(self, _, task_instances, update_extras, url, expecte
task_instances=task_instances,
)
response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"})
- self.assertEqual(response.status_code, 200)
- self.assertEqual(response.json["total_entries"], expected_ti)
- self.assertEqual(len(response.json["task_instances"]), expected_ti)
+ assert response.status_code == 200
+ assert response.json["total_entries"] == expected_ti
+ assert len(response.json["task_instances"]) == expected_ti
@provide_session
def test_should_respond_200_for_dag_id_filter(self, session):
@@ -420,10 +414,10 @@ def test_should_respond_200_for_dag_id_filter(self, session):
environ_overrides={"REMOTE_USER": "test"},
)
- self.assertEqual(response.status_code, 200)
+ assert response.status_code == 200
count = session.query(TaskInstance).filter(TaskInstance.dag_id == "example_python_operator").count()
- self.assertEqual(count, response.json["total_entries"])
- self.assertEqual(count, len(response.json["task_instances"]))
+ assert count == response.json["total_entries"]
+ assert count == len(response.json["task_instances"])
def test_should_raises_401_unauthenticated(self):
response = self.client.get(
@@ -555,9 +549,9 @@ def test_should_respond_200(
environ_overrides={"REMOTE_USER": "test"},
json=payload,
)
- self.assertEqual(response.status_code, 200, response.json)
- self.assertEqual(expected_ti_count, response.json["total_entries"])
- self.assertEqual(expected_ti_count, len(response.json["task_instances"]))
+ assert response.status_code == 200, response.json
+ assert expected_ti_count == response.json["total_entries"]
+ assert expected_ti_count == len(response.json["task_instances"])
@parameterized.expand(
[
@@ -595,9 +589,9 @@ def test_should_respond_200_when_task_instance_properties_are_none(
environ_overrides={"REMOTE_USER": "test"},
json=payload,
)
- self.assertEqual(response.status_code, 200, response.json)
- self.assertEqual(expected_ti_count, response.json["total_entries"])
- self.assertEqual(expected_ti_count, len(response.json["task_instances"]))
+ assert response.status_code == 200, response.json
+ assert expected_ti_count == response.json["total_entries"]
+ assert expected_ti_count == len(response.json["task_instances"])
@parameterized.expand(
[
@@ -618,9 +612,9 @@ def test_should_respond_200_dag_ids_filter(self, _, payload, expected_ti, total_
environ_overrides={"REMOTE_USER": "test"},
json=payload,
)
- self.assertEqual(response.status_code, 200)
- self.assertEqual(len(response.json["task_instances"]), expected_ti)
- self.assertEqual(response.json["total_entries"], total_ti)
+ assert response.status_code == 200
+ assert len(response.json["task_instances"]) == expected_ti
+ assert response.json["total_entries"] == total_ti
def test_should_raises_401_unauthenticated(self):
response = self.client.post(
@@ -815,8 +809,8 @@ def test_should_respond_200(
environ_overrides={"REMOTE_USER": "test"},
json=payload,
)
- self.assertEqual(response.status_code, 200)
- self.assertEqual(len(response.json["task_instances"]), expected_ti)
+ assert response.status_code == 200
+ assert len(response.json["task_instances"]) == expected_ti
@provide_session
def test_should_respond_200_with_reset_dag_run(self, session):
@@ -865,7 +859,7 @@ def test_should_respond_200_with_reset_dag_run(self, session):
failed_dag_runs = (
session.query(DagRun).filter(DagRun.state == "failed").count() # pylint: disable=W0143
)
- self.assertEqual(200, response.status_code)
+ assert 200 == response.status_code
expected_response = [
{
'dag_id': 'example_python_operator',
@@ -899,9 +893,9 @@ def test_should_respond_200_with_reset_dag_run(self, session):
},
]
for task_instance in expected_response:
- self.assertIn(task_instance, response.json["task_instances"])
- self.assertEqual(5, len(response.json["task_instances"]))
- self.assertEqual(0, failed_dag_runs, 0)
+ assert task_instance in response.json["task_instances"]
+ assert 5 == len(response.json["task_instances"])
+ assert 0 == failed_dag_runs, 0
def test_should_raises_401_unauthenticated(self):
response = self.client.post(
@@ -961,7 +955,7 @@ def test_should_raise_400_for_naive_and_bad_datetime(self, payload, expected, se
json=payload,
)
assert response.status_code == 400
- self.assertEqual(response.json['detail'], expected)
+ assert response.json['detail'] == expected
class TestPostSetTaskInstanceState(TestTaskInstanceEndpoint):
diff --git a/tests/api_connexion/endpoints/test_variable_endpoint.py b/tests/api_connexion/endpoints/test_variable_endpoint.py
index 5d8685f3770c8..0be9fd27ff6e7 100644
--- a/tests/api_connexion/endpoints/test_variable_endpoint.py
+++ b/tests/api_connexion/endpoints/test_variable_endpoint.py
@@ -182,7 +182,7 @@ def test_should_return_conf_max_if_req_max_above_conf(self):
Variable.set(f"var{i}", i)
response = self.client.get("/api/v1/variables?limit=180", environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == 200
- self.assertEqual(len(response.json['variables']), 150)
+ assert len(response.json['variables']) == 150
def test_should_raises_401_unauthenticated(self):
Variable.set("var1", 1)
diff --git a/tests/api_connexion/endpoints/test_version_endpoint.py b/tests/api_connexion/endpoints/test_version_endpoint.py
index 072e1f6b3eaae..f046669c8edc6 100644
--- a/tests/api_connexion/endpoints/test_version_endpoint.py
+++ b/tests/api_connexion/endpoints/test_version_endpoint.py
@@ -36,6 +36,6 @@ def setUp(self) -> None:
def test_should_respond_200(self, mock_get_airflow_get_commit):
response = self.client.get("/api/v1/version")
- self.assertEqual(200, response.status_code)
- self.assertEqual({'git_version': 'GIT_COMMIT', 'version': 'MOCK_VERSION'}, response.json)
+ assert 200 == response.status_code
+ assert {'git_version': 'GIT_COMMIT', 'version': 'MOCK_VERSION'} == response.json
mock_get_airflow_get_commit.assert_called_once_with()
diff --git a/tests/api_connexion/endpoints/test_xcom_endpoint.py b/tests/api_connexion/endpoints/test_xcom_endpoint.py
index 95ce05e1639ee..e38bf5f38f5c3 100644
--- a/tests/api_connexion/endpoints/test_xcom_endpoint.py
+++ b/tests/api_connexion/endpoints/test_xcom_endpoint.py
@@ -103,21 +103,18 @@ def test_should_respond_200(self):
f"/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries/{xcom_key}",
environ_overrides={'REMOTE_USER': "test"},
)
- self.assertEqual(200, response.status_code)
+ assert 200 == response.status_code
current_data = response.json
current_data['timestamp'] = 'TIMESTAMP'
- self.assertEqual(
- current_data,
- {
- 'dag_id': dag_id,
- 'execution_date': execution_date,
- 'key': xcom_key,
- 'task_id': task_id,
- 'timestamp': 'TIMESTAMP',
- 'value': 'TEST_VALUE',
- },
- )
+ assert current_data == {
+ 'dag_id': dag_id,
+ 'execution_date': execution_date,
+ 'key': xcom_key,
+ 'task_id': task_id,
+ 'timestamp': 'TIMESTAMP',
+ 'value': 'TEST_VALUE',
+ }
def test_should_raises_401_unauthenticated(self):
dag_id = 'test-dag-id'
@@ -181,32 +178,29 @@ def test_should_respond_200(self):
environ_overrides={'REMOTE_USER': "test"},
)
- self.assertEqual(200, response.status_code)
+ assert 200 == response.status_code
response_data = response.json
for xcom_entry in response_data['xcom_entries']:
xcom_entry['timestamp'] = "TIMESTAMP"
- self.assertEqual(
- response.json,
- {
- 'xcom_entries': [
- {
- 'dag_id': dag_id,
- 'execution_date': execution_date,
- 'key': 'test-xcom-key-1',
- 'task_id': task_id,
- 'timestamp': "TIMESTAMP",
- },
- {
- 'dag_id': dag_id,
- 'execution_date': execution_date,
- 'key': 'test-xcom-key-2',
- 'task_id': task_id,
- 'timestamp': "TIMESTAMP",
- },
- ],
- 'total_entries': 2,
- },
- )
+ assert response.json == {
+ 'xcom_entries': [
+ {
+ 'dag_id': dag_id,
+ 'execution_date': execution_date,
+ 'key': 'test-xcom-key-1',
+ 'task_id': task_id,
+ 'timestamp': "TIMESTAMP",
+ },
+ {
+ 'dag_id': dag_id,
+ 'execution_date': execution_date,
+ 'key': 'test-xcom-key-2',
+ 'task_id': task_id,
+ 'timestamp': "TIMESTAMP",
+ },
+ ],
+ 'total_entries': 2,
+ }
def test_should_respond_200_with_tilde_and_access_to_all_dags(self):
dag_id_1 = 'test-dag-id-1'
@@ -227,46 +221,43 @@ def test_should_respond_200_with_tilde_and_access_to_all_dags(self):
environ_overrides={'REMOTE_USER': "test"},
)
- self.assertEqual(200, response.status_code)
+ assert 200 == response.status_code
response_data = response.json
for xcom_entry in response_data['xcom_entries']:
xcom_entry['timestamp'] = "TIMESTAMP"
- self.assertEqual(
- response.json,
- {
- 'xcom_entries': [
- {
- 'dag_id': dag_id_1,
- 'execution_date': execution_date,
- 'key': 'test-xcom-key-1',
- 'task_id': task_id_1,
- 'timestamp': "TIMESTAMP",
- },
- {
- 'dag_id': dag_id_1,
- 'execution_date': execution_date,
- 'key': 'test-xcom-key-2',
- 'task_id': task_id_1,
- 'timestamp': "TIMESTAMP",
- },
- {
- 'dag_id': dag_id_2,
- 'execution_date': execution_date,
- 'key': 'test-xcom-key-1',
- 'task_id': task_id_2,
- 'timestamp': "TIMESTAMP",
- },
- {
- 'dag_id': dag_id_2,
- 'execution_date': execution_date,
- 'key': 'test-xcom-key-2',
- 'task_id': task_id_2,
- 'timestamp': "TIMESTAMP",
- },
- ],
- 'total_entries': 4,
- },
- )
+ assert response.json == {
+ 'xcom_entries': [
+ {
+ 'dag_id': dag_id_1,
+ 'execution_date': execution_date,
+ 'key': 'test-xcom-key-1',
+ 'task_id': task_id_1,
+ 'timestamp': "TIMESTAMP",
+ },
+ {
+ 'dag_id': dag_id_1,
+ 'execution_date': execution_date,
+ 'key': 'test-xcom-key-2',
+ 'task_id': task_id_1,
+ 'timestamp': "TIMESTAMP",
+ },
+ {
+ 'dag_id': dag_id_2,
+ 'execution_date': execution_date,
+ 'key': 'test-xcom-key-1',
+ 'task_id': task_id_2,
+ 'timestamp': "TIMESTAMP",
+ },
+ {
+ 'dag_id': dag_id_2,
+ 'execution_date': execution_date,
+ 'key': 'test-xcom-key-2',
+ 'task_id': task_id_2,
+ 'timestamp': "TIMESTAMP",
+ },
+ ],
+ 'total_entries': 4,
+ }
def test_should_respond_200_with_tilde_and_granular_dag_access(self):
dag_id_1 = 'test-dag-id-1'
@@ -286,32 +277,29 @@ def test_should_respond_200_with_tilde_and_granular_dag_access(self):
environ_overrides={'REMOTE_USER': "test_granular_permissions"},
)
- self.assertEqual(200, response.status_code)
+ assert 200 == response.status_code
response_data = response.json
for xcom_entry in response_data['xcom_entries']:
xcom_entry['timestamp'] = "TIMESTAMP"
- self.assertEqual(
- response.json,
- {
- 'xcom_entries': [
- {
- 'dag_id': dag_id_1,
- 'execution_date': execution_date,
- 'key': 'test-xcom-key-1',
- 'task_id': task_id_1,
- 'timestamp': "TIMESTAMP",
- },
- {
- 'dag_id': dag_id_1,
- 'execution_date': execution_date,
- 'key': 'test-xcom-key-2',
- 'task_id': task_id_1,
- 'timestamp': "TIMESTAMP",
- },
- ],
- 'total_entries': 2,
- },
- )
+ assert response.json == {
+ 'xcom_entries': [
+ {
+ 'dag_id': dag_id_1,
+ 'execution_date': execution_date,
+ 'key': 'test-xcom-key-1',
+ 'task_id': task_id_1,
+ 'timestamp': "TIMESTAMP",
+ },
+ {
+ 'dag_id': dag_id_1,
+ 'execution_date': execution_date,
+ 'key': 'test-xcom-key-2',
+ 'task_id': task_id_1,
+ 'timestamp': "TIMESTAMP",
+ },
+ ],
+ 'total_entries': 2,
+ }
def test_should_raises_401_unauthenticated(self):
dag_id = 'test-dag-id'
@@ -461,9 +449,9 @@ def test_handle_limit_offset(self, query_params, expected_xcom_ids, session):
session.commit()
response = self.client.get(url, environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == 200
- self.assertEqual(response.json["total_entries"], 10)
+ assert response.json["total_entries"] == 10
conn_ids = [conn["key"] for conn in response.json["xcom_entries"] if conn]
- self.assertEqual(conn_ids, expected_xcom_ids)
+ assert conn_ids == expected_xcom_ids
def _create_xcoms(self, count):
return [
diff --git a/tests/api_connexion/schemas/test_common_schema.py b/tests/api_connexion/schemas/test_common_schema.py
index c6cecdb27a3f8..c734483a583f0 100644
--- a/tests/api_connexion/schemas/test_common_schema.py
+++ b/tests/api_connexion/schemas/test_common_schema.py
@@ -18,6 +18,7 @@
import datetime
import unittest
+import pytest
from dateutil import relativedelta
from airflow.api_connexion.schemas.common_schema import (
@@ -34,14 +35,14 @@ def test_should_serialize(self):
instance = datetime.timedelta(days=12)
schema_instance = TimeDeltaSchema()
result = schema_instance.dump(instance)
- self.assertEqual({"__type": "TimeDelta", "days": 12, "seconds": 0, "microseconds": 0}, result)
+ assert {"__type": "TimeDelta", "days": 12, "seconds": 0, "microseconds": 0} == result
def test_should_deserialize(self):
instance = {"__type": "TimeDelta", "days": 12, "seconds": 0, "microseconds": 0}
schema_instance = TimeDeltaSchema()
result = schema_instance.load(instance)
expected_instance = datetime.timedelta(days=12)
- self.assertEqual(expected_instance, result)
+ assert expected_instance == result
class TestRelativeDeltaSchema(unittest.TestCase):
@@ -49,34 +50,31 @@ def test_should_serialize(self):
instance = relativedelta.relativedelta(days=+12)
schema_instance = RelativeDeltaSchema()
result = schema_instance.dump(instance)
- self.assertEqual(
- {
- '__type': 'RelativeDelta',
- "day": None,
- "days": 12,
- "hour": None,
- "hours": 0,
- "leapdays": 0,
- "microsecond": None,
- "microseconds": 0,
- "minute": None,
- "minutes": 0,
- "month": None,
- "months": 0,
- "second": None,
- "seconds": 0,
- "year": None,
- "years": 0,
- },
- result,
- )
+ assert {
+ '__type': 'RelativeDelta',
+ "day": None,
+ "days": 12,
+ "hour": None,
+ "hours": 0,
+ "leapdays": 0,
+ "microsecond": None,
+ "microseconds": 0,
+ "minute": None,
+ "minutes": 0,
+ "month": None,
+ "months": 0,
+ "second": None,
+ "seconds": 0,
+ "year": None,
+ "years": 0,
+ } == result
def test_should_deserialize(self):
instance = {"__type": "RelativeDelta", "days": 12, "seconds": 0}
schema_instance = RelativeDeltaSchema()
result = schema_instance.load(instance)
expected_instance = relativedelta.relativedelta(days=+12)
- self.assertEqual(expected_instance, result)
+ assert expected_instance == result
class TestCronExpressionSchema(unittest.TestCase):
@@ -85,7 +83,7 @@ def test_should_deserialize(self):
schema_instance = CronExpressionSchema()
result = schema_instance.load(instance)
expected_instance = CronExpression("5 4 * * *")
- self.assertEqual(expected_instance, result)
+ assert expected_instance == result
class TestScheduleIntervalSchema(unittest.TestCase):
@@ -93,57 +91,54 @@ def test_should_serialize_timedelta(self):
instance = datetime.timedelta(days=12)
schema_instance = ScheduleIntervalSchema()
result = schema_instance.dump(instance)
- self.assertEqual({"__type": "TimeDelta", "days": 12, "seconds": 0, "microseconds": 0}, result)
+ assert {"__type": "TimeDelta", "days": 12, "seconds": 0, "microseconds": 0} == result
def test_should_deserialize_timedelta(self):
instance = {"__type": "TimeDelta", "days": 12, "seconds": 0, "microseconds": 0}
schema_instance = ScheduleIntervalSchema()
result = schema_instance.load(instance)
expected_instance = datetime.timedelta(days=12)
- self.assertEqual(expected_instance, result)
+ assert expected_instance == result
def test_should_serialize_relative_delta(self):
instance = relativedelta.relativedelta(days=+12)
schema_instance = ScheduleIntervalSchema()
result = schema_instance.dump(instance)
- self.assertEqual(
- {
- "__type": "RelativeDelta",
- "day": None,
- "days": 12,
- "hour": None,
- "hours": 0,
- "leapdays": 0,
- "microsecond": None,
- "microseconds": 0,
- "minute": None,
- "minutes": 0,
- "month": None,
- "months": 0,
- "second": None,
- "seconds": 0,
- "year": None,
- "years": 0,
- },
- result,
- )
+ assert {
+ "__type": "RelativeDelta",
+ "day": None,
+ "days": 12,
+ "hour": None,
+ "hours": 0,
+ "leapdays": 0,
+ "microsecond": None,
+ "microseconds": 0,
+ "minute": None,
+ "minutes": 0,
+ "month": None,
+ "months": 0,
+ "second": None,
+ "seconds": 0,
+ "year": None,
+ "years": 0,
+ } == result
def test_should_deserialize_relative_delta(self):
instance = {"__type": "RelativeDelta", "days": 12, "seconds": 0}
schema_instance = ScheduleIntervalSchema()
result = schema_instance.load(instance)
expected_instance = relativedelta.relativedelta(days=+12)
- self.assertEqual(expected_instance, result)
+ assert expected_instance == result
def test_should_serialize_cron_expression(self):
instance = "5 4 * * *"
schema_instance = ScheduleIntervalSchema()
result = schema_instance.dump(instance)
expected_instance = {"__type": "CronExpression", "value": "5 4 * * *"}
- self.assertEqual(expected_instance, result)
+ assert expected_instance == result
def test_should_error_unknown_obj_type(self):
instance = 342
schema_instance = ScheduleIntervalSchema()
- with self.assertRaisesRegex(Exception, "Unknown object type: int"):
+ with pytest.raises(Exception, match="Unknown object type: int"):
schema_instance.dump(instance)
diff --git a/tests/api_connexion/schemas/test_connection_schema.py b/tests/api_connexion/schemas/test_connection_schema.py
index 5a4c580099a21..983a735719b62 100644
--- a/tests/api_connexion/schemas/test_connection_schema.py
+++ b/tests/api_connexion/schemas/test_connection_schema.py
@@ -18,6 +18,7 @@
import unittest
import marshmallow
+import pytest
from airflow.api_connexion.schemas.connection_schema import (
ConnectionCollection,
@@ -52,17 +53,14 @@ def test_serialize(self, session):
session.commit()
connection_model = session.query(Connection).first()
deserialized_connection = connection_collection_item_schema.dump(connection_model)
- self.assertEqual(
- deserialized_connection,
- {
- 'connection_id': "mysql_default",
- 'conn_type': 'mysql',
- 'host': 'mysql',
- 'login': 'login',
- 'schema': 'testschema',
- 'port': 80,
- },
- )
+ assert deserialized_connection == {
+ 'connection_id': "mysql_default",
+ 'conn_type': 'mysql',
+ 'host': 'mysql',
+ 'login': 'login',
+ 'schema': 'testschema',
+ 'port': 80,
+ }
def test_deserialize(self):
connection_dump_1 = {
@@ -80,32 +78,26 @@ def test_deserialize(self):
result_1 = connection_collection_item_schema.load(connection_dump_1)
result_2 = connection_collection_item_schema.load(connection_dump_2)
- self.assertEqual(
- result_1,
- {
- 'conn_id': "mysql_default_1",
- 'conn_type': 'mysql',
- 'host': 'mysql',
- 'login': 'login',
- 'schema': 'testschema',
- 'port': 80,
- },
- )
- self.assertEqual(
- result_2,
- {
- 'conn_id': "mysql_default_2",
- 'conn_type': "postgres",
- },
- )
+ assert result_1 == {
+ 'conn_id': "mysql_default_1",
+ 'conn_type': 'mysql',
+ 'host': 'mysql',
+ 'login': 'login',
+ 'schema': 'testschema',
+ 'port': 80,
+ }
+ assert result_2 == {
+ 'conn_id': "mysql_default_2",
+ 'conn_type': "postgres",
+ }
def test_deserialize_required_fields(self):
connection_dump_1 = {
'connection_id': "mysql_default_2",
}
- with self.assertRaisesRegex(
+ with pytest.raises(
marshmallow.exceptions.ValidationError,
- re.escape("{'conn_type': ['Missing data for required field.']}"),
+ match=re.escape("{'conn_type': ['Missing data for required field.']}"),
):
connection_collection_item_schema.load(connection_dump_1)
@@ -127,30 +119,27 @@ def test_serialize(self, session):
session.commit()
instance = ConnectionCollection(connections=connections, total_entries=2)
deserialized_connections = connection_collection_schema.dump(instance)
- self.assertEqual(
- deserialized_connections,
- {
- 'connections': [
- {
- "connection_id": "mysql_default_1",
- "conn_type": "test-type",
- "host": None,
- "login": None,
- 'schema': None,
- 'port': None,
- },
- {
- "connection_id": "mysql_default_2",
- "conn_type": "test-type2",
- "host": None,
- "login": None,
- 'schema': None,
- 'port': None,
- },
- ],
- 'total_entries': 2,
- },
- )
+ assert deserialized_connections == {
+ 'connections': [
+ {
+ "connection_id": "mysql_default_1",
+ "conn_type": "test-type",
+ "host": None,
+ "login": None,
+ 'schema': None,
+ 'port': None,
+ },
+ {
+ "connection_id": "mysql_default_2",
+ "conn_type": "test-type2",
+ "host": None,
+ "login": None,
+ 'schema': None,
+ 'port': None,
+ },
+ ],
+ 'total_entries': 2,
+ }
class TestConnectionSchema(unittest.TestCase):
@@ -177,18 +166,15 @@ def test_serialize(self, session):
session.commit()
connection_model = session.query(Connection).first()
deserialized_connection = connection_schema.dump(connection_model)
- self.assertEqual(
- deserialized_connection,
- {
- 'connection_id': "mysql_default",
- 'conn_type': 'mysql',
- 'host': 'mysql',
- 'login': 'login',
- 'schema': 'testschema',
- 'port': 80,
- 'extra': "{'key':'string'}",
- },
- )
+ assert deserialized_connection == {
+ 'connection_id': "mysql_default",
+ 'conn_type': 'mysql',
+ 'host': 'mysql',
+ 'login': 'login',
+ 'schema': 'testschema',
+ 'port': 80,
+ 'extra': "{'key':'string'}",
+ }
def test_deserialize(self):
den = {
@@ -201,15 +187,12 @@ def test_deserialize(self):
'extra': "{'key':'string'}",
}
result = connection_schema.load(den)
- self.assertEqual(
- result,
- {
- 'conn_id': "mysql_default",
- 'conn_type': 'mysql',
- 'host': 'mysql',
- 'login': 'login',
- 'schema': 'testschema',
- 'port': 80,
- 'extra': "{'key':'string'}",
- },
- )
+ assert result == {
+ 'conn_id': "mysql_default",
+ 'conn_type': 'mysql',
+ 'host': 'mysql',
+ 'login': 'login',
+ 'schema': 'testschema',
+ 'port': 80,
+ 'extra': "{'key':'string'}",
+ }
diff --git a/tests/api_connexion/schemas/test_dag_run_schema.py b/tests/api_connexion/schemas/test_dag_run_schema.py
index 5569dbc5bb8eb..3e6bf2eb071a9 100644
--- a/tests/api_connexion/schemas/test_dag_run_schema.py
+++ b/tests/api_connexion/schemas/test_dag_run_schema.py
@@ -16,6 +16,7 @@
# under the License.
import unittest
+import pytest
from dateutil.parser import parse
from parameterized import parameterized
@@ -58,19 +59,16 @@ def test_serialize(self, session):
dagrun_model = session.query(DagRun).first()
deserialized_dagrun = dagrun_schema.dump(dagrun_model)
- self.assertEqual(
- deserialized_dagrun,
- {
- "dag_id": None,
- "dag_run_id": "my-dag-run",
- "end_date": None,
- "state": "running",
- "execution_date": self.default_time,
- "external_trigger": True,
- "start_date": self.default_time,
- "conf": {"start": "stop"},
- },
- )
+ assert deserialized_dagrun == {
+ "dag_id": None,
+ "dag_run_id": "my-dag-run",
+ "end_date": None,
+ "state": "running",
+ "execution_date": self.default_time,
+ "external_trigger": True,
+ "start_date": self.default_time,
+ "conf": {"start": "stop"},
+ }
@parameterized.expand(
[
@@ -106,22 +104,19 @@ def test_serialize(self, session):
)
def test_deserialize(self, serialized_dagrun, expected_result):
result = dagrun_schema.load(serialized_dagrun)
- self.assertDictEqual(result, expected_result)
+ assert result == expected_result
def test_autofill_fields(self):
"""Dag_run_id and execution_date fields are autogenerated if missing"""
serialized_dagrun = {}
result = dagrun_schema.load(serialized_dagrun)
- self.assertDictEqual(
- result,
- {"execution_date": result["execution_date"], "run_id": result["run_id"]},
- )
+ assert result == {"execution_date": result["execution_date"], "run_id": result["run_id"]}
def test_invalid_execution_date_raises(self):
serialized_dagrun = {"execution_date": "mydate"}
- with self.assertRaises(BadRequest) as e:
+ with pytest.raises(BadRequest) as ctx:
dagrun_schema.load(serialized_dagrun)
- self.assertEqual(str(e.exception), "Incorrect datetime argument")
+ assert str(ctx.value) == "Incorrect datetime argument"
class TestDagRunCollection(TestDAGRunBase):
@@ -145,31 +140,28 @@ def test_serialize(self, session):
session.commit()
instance = DAGRunCollection(dag_runs=dagruns, total_entries=2)
deserialized_dagruns = dagrun_collection_schema.dump(instance)
- self.assertEqual(
- deserialized_dagruns,
- {
- "dag_runs": [
- {
- "dag_id": None,
- "dag_run_id": "my-dag-run",
- "end_date": None,
- "execution_date": self.default_time,
- "external_trigger": True,
- "state": "running",
- "start_date": self.default_time,
- "conf": {"start": "stop"},
- },
- {
- "dag_id": None,
- "dag_run_id": "my-dag-run-2",
- "end_date": None,
- "state": "running",
- "execution_date": self.default_time,
- "external_trigger": True,
- "start_date": self.default_time,
- "conf": {},
- },
- ],
- "total_entries": 2,
- },
- )
+ assert deserialized_dagruns == {
+ "dag_runs": [
+ {
+ "dag_id": None,
+ "dag_run_id": "my-dag-run",
+ "end_date": None,
+ "execution_date": self.default_time,
+ "external_trigger": True,
+ "state": "running",
+ "start_date": self.default_time,
+ "conf": {"start": "stop"},
+ },
+ {
+ "dag_id": None,
+ "dag_run_id": "my-dag-run-2",
+ "end_date": None,
+ "state": "running",
+ "execution_date": self.default_time,
+ "external_trigger": True,
+ "start_date": self.default_time,
+ "conf": {},
+ },
+ ],
+ "total_entries": 2,
+ }
diff --git a/tests/api_connexion/schemas/test_dag_schema.py b/tests/api_connexion/schemas/test_dag_schema.py
index bc18b3ced6974..4b6379550c7c6 100644
--- a/tests/api_connexion/schemas/test_dag_schema.py
+++ b/tests/api_connexion/schemas/test_dag_schema.py
@@ -47,21 +47,18 @@ def test_serialize(self):
tags=[DagTag(name="tag-1"), DagTag(name="tag-2")],
)
serialized_dag = DAGSchema().dump(dag_model)
- self.assertEqual(
- {
- "dag_id": "test_dag_id",
- "description": "The description",
- "fileloc": "/root/airflow/dags/my_dag.py",
- "file_token": SERIALIZER.dumps("/root/airflow/dags/my_dag.py"),
- "is_paused": True,
- "is_subdag": False,
- "owners": ["airflow1", "airflow2"],
- "root_dag_id": "test_root_dag_id",
- "schedule_interval": {"__type": "CronExpression", "value": "5 4 * * *"},
- "tags": [{"name": "tag-1"}, {"name": "tag-2"}],
- },
- serialized_dag,
- )
+ assert {
+ "dag_id": "test_dag_id",
+ "description": "The description",
+ "fileloc": "/root/airflow/dags/my_dag.py",
+ "file_token": SERIALIZER.dumps("/root/airflow/dags/my_dag.py"),
+ "is_paused": True,
+ "is_subdag": False,
+ "owners": ["airflow1", "airflow2"],
+ "root_dag_id": "test_root_dag_id",
+ "schedule_interval": {"__type": "CronExpression", "value": "5 4 * * *"},
+ "tags": [{"name": "tag-1"}, {"name": "tag-2"}],
+ } == serialized_dag
class TestDAGCollectionSchema(unittest.TestCase):
@@ -70,38 +67,35 @@ def test_serialize(self):
dag_model_b = DagModel(dag_id="test_dag_id_b", fileloc="/tmp/a.py")
schema = DAGCollectionSchema()
instance = DAGCollection(dags=[dag_model_a, dag_model_b], total_entries=2)
- self.assertEqual(
- {
- "dags": [
- {
- "dag_id": "test_dag_id_a",
- "description": None,
- "fileloc": "/tmp/a.py",
- "file_token": SERIALIZER.dumps("/tmp/a.py"),
- "is_paused": None,
- "is_subdag": None,
- "owners": [],
- "root_dag_id": None,
- "schedule_interval": None,
- "tags": [],
- },
- {
- "dag_id": "test_dag_id_b",
- "description": None,
- "fileloc": "/tmp/a.py",
- "file_token": SERIALIZER.dumps("/tmp/a.py"),
- "is_paused": None,
- "is_subdag": None,
- "owners": [],
- "root_dag_id": None,
- "schedule_interval": None,
- "tags": [],
- },
- ],
- "total_entries": 2,
- },
- schema.dump(instance),
- )
+ assert {
+ "dags": [
+ {
+ "dag_id": "test_dag_id_a",
+ "description": None,
+ "fileloc": "/tmp/a.py",
+ "file_token": SERIALIZER.dumps("/tmp/a.py"),
+ "is_paused": None,
+ "is_subdag": None,
+ "owners": [],
+ "root_dag_id": None,
+ "schedule_interval": None,
+ "tags": [],
+ },
+ {
+ "dag_id": "test_dag_id_b",
+ "description": None,
+ "fileloc": "/tmp/a.py",
+ "file_token": SERIALIZER.dumps("/tmp/a.py"),
+ "is_paused": None,
+ "is_subdag": None,
+ "owners": [],
+ "root_dag_id": None,
+ "schedule_interval": None,
+ "tags": [],
+ },
+ ],
+ "total_entries": 2,
+ } == schema.dump(instance)
class TestDAGDetailSchema:
diff --git a/tests/api_connexion/schemas/test_error_schema.py b/tests/api_connexion/schemas/test_error_schema.py
index c2f6aef24cbcb..02d574f10403f 100644
--- a/tests/api_connexion/schemas/test_error_schema.py
+++ b/tests/api_connexion/schemas/test_error_schema.py
@@ -48,15 +48,12 @@ def test_serialize(self, session):
session.commit()
serialized_data = import_error_schema.dump(import_error)
serialized_data["import_error_id"] = 1
- self.assertEqual(
- {
- "filename": "lorem.py",
- "import_error_id": 1,
- "stack_trace": "Lorem Ipsum",
- "timestamp": "2020-06-10T12:02:44+00:00",
- },
- serialized_data,
- )
+ assert {
+ "filename": "lorem.py",
+ "import_error_id": 1,
+ "stack_trace": "Lorem Ipsum",
+ "timestamp": "2020-06-10T12:02:44+00:00",
+ } == serialized_data
class TestErrorCollectionSchema(TestErrorSchemaBase):
@@ -80,23 +77,20 @@ def test_serialize(self, session):
# To maintain consistency in the key sequence across the db in tests
serialized_data["import_errors"][0]["import_error_id"] = 1
serialized_data["import_errors"][1]["import_error_id"] = 2
- self.assertEqual(
- {
- "import_errors": [
- {
- "filename": "Lorem_ipsum.py",
- "import_error_id": 1,
- "stack_trace": "Lorem ipsum",
- "timestamp": "2020-06-10T12:02:44+00:00",
- },
- {
- "filename": "Lorem_ipsum.py",
- "import_error_id": 2,
- "stack_trace": "Lorem ipsum",
- "timestamp": "2020-06-10T12:02:44+00:00",
- },
- ],
- "total_entries": 2,
- },
- serialized_data,
- )
+ assert {
+ "import_errors": [
+ {
+ "filename": "Lorem_ipsum.py",
+ "import_error_id": 1,
+ "stack_trace": "Lorem ipsum",
+ "timestamp": "2020-06-10T12:02:44+00:00",
+ },
+ {
+ "filename": "Lorem_ipsum.py",
+ "import_error_id": 2,
+ "stack_trace": "Lorem ipsum",
+ "timestamp": "2020-06-10T12:02:44+00:00",
+ },
+ ],
+ "total_entries": 2,
+ } == serialized_data
diff --git a/tests/api_connexion/schemas/test_event_log_schema.py b/tests/api_connexion/schemas/test_event_log_schema.py
index b4c2003528a4b..597ecc71b61f2 100644
--- a/tests/api_connexion/schemas/test_event_log_schema.py
+++ b/tests/api_connexion/schemas/test_event_log_schema.py
@@ -59,19 +59,16 @@ def test_serialize(self, session):
event_log_model.dttm = timezone.parse(self.default_time)
log_model = session.query(Log).first()
deserialized_log = event_log_schema.dump(log_model)
- self.assertEqual(
- deserialized_log,
- {
- "event_log_id": event_log_model.id,
- "event": "TEST_EVENT",
- "dag_id": "TEST_DAG_ID",
- "task_id": "TEST_TASK_ID",
- "execution_date": self.default_time,
- "owner": 'airflow',
- "when": self.default_time,
- "extra": None,
- },
- )
+ assert deserialized_log == {
+ "event_log_id": event_log_model.id,
+ "event": "TEST_EVENT",
+ "dag_id": "TEST_DAG_ID",
+ "task_id": "TEST_TASK_ID",
+ "execution_date": self.default_time,
+ "owner": 'airflow',
+ "when": self.default_time,
+ "extra": None,
+ }
class TestEventLogCollection(TestEventLogSchemaBase):
@@ -86,31 +83,28 @@ def test_serialize(self, session):
event_log_model_2.dttm = timezone.parse(self.default_time2)
instance = EventLogCollection(event_logs=event_logs, total_entries=2)
deserialized_event_logs = event_log_collection_schema.dump(instance)
- self.assertEqual(
- deserialized_event_logs,
- {
- "event_logs": [
- {
- "event_log_id": event_log_model_1.id,
- "event": "TEST_EVENT_1",
- "dag_id": "TEST_DAG_ID",
- "task_id": "TEST_TASK_ID",
- "execution_date": self.default_time,
- "owner": 'airflow',
- "when": self.default_time,
- "extra": None,
- },
- {
- "event_log_id": event_log_model_2.id,
- "event": "TEST_EVENT_2",
- "dag_id": "TEST_DAG_ID",
- "task_id": "TEST_TASK_ID",
- "execution_date": self.default_time,
- "owner": 'airflow',
- "when": self.default_time2,
- "extra": None,
- },
- ],
- "total_entries": 2,
- },
- )
+ assert deserialized_event_logs == {
+ "event_logs": [
+ {
+ "event_log_id": event_log_model_1.id,
+ "event": "TEST_EVENT_1",
+ "dag_id": "TEST_DAG_ID",
+ "task_id": "TEST_TASK_ID",
+ "execution_date": self.default_time,
+ "owner": 'airflow',
+ "when": self.default_time,
+ "extra": None,
+ },
+ {
+ "event_log_id": event_log_model_2.id,
+ "event": "TEST_EVENT_2",
+ "dag_id": "TEST_DAG_ID",
+ "task_id": "TEST_TASK_ID",
+ "execution_date": self.default_time,
+ "owner": 'airflow',
+ "when": self.default_time2,
+ "extra": None,
+ },
+ ],
+ "total_entries": 2,
+ }
diff --git a/tests/api_connexion/schemas/test_health_schema.py b/tests/api_connexion/schemas/test_health_schema.py
index e7e1ff6336efc..339da2cc575da 100644
--- a/tests/api_connexion/schemas/test_health_schema.py
+++ b/tests/api_connexion/schemas/test_health_schema.py
@@ -32,4 +32,4 @@ def test_serialize(self):
},
}
serialized_data = health_schema.dump(payload)
- self.assertDictEqual(serialized_data, payload)
+ assert serialized_data == payload
diff --git a/tests/api_connexion/schemas/test_pool_schemas.py b/tests/api_connexion/schemas/test_pool_schemas.py
index 53b69633df852..d1cb2c49c658d 100644
--- a/tests/api_connexion/schemas/test_pool_schemas.py
+++ b/tests/api_connexion/schemas/test_pool_schemas.py
@@ -37,23 +37,20 @@ def test_serialize(self, session):
session.commit()
pool_instance = session.query(Pool).filter(Pool.pool == pool_model.pool).first()
serialized_pool = pool_schema.dump(pool_instance)
- self.assertEqual(
- serialized_pool,
- {
- "name": "test_pool",
- "slots": 2,
- "occupied_slots": 0,
- "running_slots": 0,
- "queued_slots": 0,
- "open_slots": 2,
- },
- )
+ assert serialized_pool == {
+ "name": "test_pool",
+ "slots": 2,
+ "occupied_slots": 0,
+ "running_slots": 0,
+ "queued_slots": 0,
+ "open_slots": 2,
+ }
@provide_session
def test_deserialize(self, session):
pool_dict = {"name": "test_pool", "slots": 3}
deserialized_pool = pool_schema.load(pool_dict, session=session)
- self.assertNotIsInstance(deserialized_pool, Pool) # Checks if load_instance is set to True
+ assert not isinstance(deserialized_pool, Pool) # Checks if load_instance is set to True
class TestPoolCollectionSchema(unittest.TestCase):
@@ -67,27 +64,24 @@ def test_serialize(self):
pool_model_a = Pool(pool="test_pool_a", slots=3)
pool_model_b = Pool(pool="test_pool_b", slots=3)
instance = PoolCollection(pools=[pool_model_a, pool_model_b], total_entries=2)
- self.assertEqual(
- {
- "pools": [
- {
- "name": "test_pool_a",
- "slots": 3,
- "occupied_slots": 0,
- "running_slots": 0,
- "queued_slots": 0,
- "open_slots": 3,
- },
- {
- "name": "test_pool_b",
- "slots": 3,
- "occupied_slots": 0,
- "running_slots": 0,
- "queued_slots": 0,
- "open_slots": 3,
- },
- ],
- "total_entries": 2,
- },
- pool_collection_schema.dump(instance),
- )
+ assert {
+ "pools": [
+ {
+ "name": "test_pool_a",
+ "slots": 3,
+ "occupied_slots": 0,
+ "running_slots": 0,
+ "queued_slots": 0,
+ "open_slots": 3,
+ },
+ {
+ "name": "test_pool_b",
+ "slots": 3,
+ "occupied_slots": 0,
+ "running_slots": 0,
+ "queued_slots": 0,
+ "open_slots": 3,
+ },
+ ],
+ "total_entries": 2,
+ } == pool_collection_schema.dump(instance)
diff --git a/tests/api_connexion/schemas/test_task_instance_schema.py b/tests/api_connexion/schemas/test_task_instance_schema.py
index 88ee16c38bdd9..9720ac0c8d9c6 100644
--- a/tests/api_connexion/schemas/test_task_instance_schema.py
+++ b/tests/api_connexion/schemas/test_task_instance_schema.py
@@ -19,6 +19,7 @@
import getpass
import unittest
+import pytest
from marshmallow import ValidationError
from parameterized import parameterized
@@ -88,7 +89,7 @@ def test_task_instance_schema_without_sla(self, session):
"try_number": 0,
"unixname": getpass.getuser(),
}
- self.assertDictEqual(serialized_ti, expected_json)
+ assert serialized_ti == expected_json
@provide_session
def test_task_instance_schema_with_sla(self, session):
@@ -134,7 +135,7 @@ def test_task_instance_schema_with_sla(self, session):
"try_number": 0,
"unixname": getpass.getuser(),
}
- self.assertDictEqual(serialized_ti, expected_json)
+ assert serialized_ti == expected_json
class TestClearTaskInstanceFormSchema(unittest.TestCase):
@@ -163,7 +164,7 @@ class TestClearTaskInstanceFormSchema(unittest.TestCase):
]
)
def test_validation_error(self, payload):
- with self.assertRaises(ValidationError):
+ with pytest.raises(ValidationError):
clear_task_instance_form.load(payload)
@@ -193,7 +194,7 @@ def test_success(self):
'new_state': 'failed',
'task_id': 'print_the_context',
}
- self.assertEqual(expected_result, result)
+ assert expected_result == result
@parameterized.expand(
[
@@ -206,5 +207,5 @@ def test_success(self):
def test_validation_error(self, override_data):
self.current_input.update(override_data)
- with self.assertRaises(ValidationError):
+ with pytest.raises(ValidationError):
clear_task_instance_form.load(self.current_input)
diff --git a/tests/api_connexion/schemas/test_version_schema.py b/tests/api_connexion/schemas/test_version_schema.py
index 2705910d4c38d..8cb654b1da617 100644
--- a/tests/api_connexion/schemas/test_version_schema.py
+++ b/tests/api_connexion/schemas/test_version_schema.py
@@ -35,4 +35,4 @@ def test_serialize(self, git_commit):
current_data = version_info_schema.dump(version_info)
expected_result = {'version': 'VERSION', 'git_version': git_commit}
- self.assertEqual(expected_result, current_data)
+ assert expected_result == current_data
diff --git a/tests/api_connexion/schemas/test_xcom_schema.py b/tests/api_connexion/schemas/test_xcom_schema.py
index 846f72729300a..b541ebdf7fae3 100644
--- a/tests/api_connexion/schemas/test_xcom_schema.py
+++ b/tests/api_connexion/schemas/test_xcom_schema.py
@@ -64,16 +64,13 @@ def test_serialize(self, session):
session.commit()
xcom_model = session.query(XCom).first()
deserialized_xcom = xcom_collection_item_schema.dump(xcom_model)
- self.assertEqual(
- deserialized_xcom,
- {
- 'key': 'test_key',
- 'timestamp': self.default_time,
- 'execution_date': self.default_time,
- 'task_id': 'test_task_id',
- 'dag_id': 'test_dag',
- },
- )
+ assert deserialized_xcom == {
+ 'key': 'test_key',
+ 'timestamp': self.default_time,
+ 'execution_date': self.default_time,
+ 'task_id': 'test_task_id',
+ 'dag_id': 'test_dag',
+ }
def test_deserialize(self):
xcom_dump = {
@@ -84,16 +81,13 @@ def test_deserialize(self):
'dag_id': 'test_dag',
}
result = xcom_collection_item_schema.load(xcom_dump)
- self.assertEqual(
- result,
- {
- 'key': 'test_key',
- 'timestamp': self.default_time_parsed,
- 'execution_date': self.default_time_parsed,
- 'task_id': 'test_task_id',
- 'dag_id': 'test_dag',
- },
- )
+ assert result == {
+ 'key': 'test_key',
+ 'timestamp': self.default_time_parsed,
+ 'execution_date': self.default_time_parsed,
+ 'task_id': 'test_task_id',
+ 'dag_id': 'test_dag',
+ }
class TestXComCollectionSchema(TestXComSchemaBase):
@@ -133,28 +127,25 @@ def test_serialize(self, session):
total_entries=xcom_models_query.count(),
)
)
- self.assertEqual(
- deserialized_xcoms,
- {
- 'xcom_entries': [
- {
- 'key': 'test_key_1',
- 'timestamp': self.default_time_1,
- 'execution_date': self.default_time_1,
- 'task_id': 'test_task_id_1',
- 'dag_id': 'test_dag_1',
- },
- {
- 'key': 'test_key_2',
- 'timestamp': self.default_time_2,
- 'execution_date': self.default_time_2,
- 'task_id': 'test_task_id_2',
- 'dag_id': 'test_dag_2',
- },
- ],
- 'total_entries': len(xcom_models),
- },
- )
+ assert deserialized_xcoms == {
+ 'xcom_entries': [
+ {
+ 'key': 'test_key_1',
+ 'timestamp': self.default_time_1,
+ 'execution_date': self.default_time_1,
+ 'task_id': 'test_task_id_1',
+ 'dag_id': 'test_dag_1',
+ },
+ {
+ 'key': 'test_key_2',
+ 'timestamp': self.default_time_2,
+ 'execution_date': self.default_time_2,
+ 'task_id': 'test_task_id_2',
+ 'dag_id': 'test_dag_2',
+ },
+ ],
+ 'total_entries': len(xcom_models),
+ }
class TestXComSchema(TestXComSchemaBase):
@@ -177,17 +168,14 @@ def test_serialize(self, session):
session.commit()
xcom_model = session.query(XCom).first()
deserialized_xcom = xcom_schema.dump(xcom_model)
- self.assertEqual(
- deserialized_xcom,
- {
- 'key': 'test_key',
- 'timestamp': self.default_time,
- 'execution_date': self.default_time,
- 'task_id': 'test_task_id',
- 'dag_id': 'test_dag',
- 'value': 'test_binary',
- },
- )
+ assert deserialized_xcom == {
+ 'key': 'test_key',
+ 'timestamp': self.default_time,
+ 'execution_date': self.default_time,
+ 'task_id': 'test_task_id',
+ 'dag_id': 'test_dag',
+ 'value': 'test_binary',
+ }
def test_deserialize(self):
xcom_dump = {
@@ -199,14 +187,11 @@ def test_deserialize(self):
'value': b'test_binary',
}
result = xcom_schema.load(xcom_dump)
- self.assertEqual(
- result,
- {
- 'key': 'test_key',
- 'timestamp': self.default_time_parsed,
- 'execution_date': self.default_time_parsed,
- 'task_id': 'test_task_id',
- 'dag_id': 'test_dag',
- 'value': 'test_binary',
- },
- )
+ assert result == {
+ 'key': 'test_key',
+ 'timestamp': self.default_time_parsed,
+ 'execution_date': self.default_time_parsed,
+ 'task_id': 'test_task_id',
+ 'dag_id': 'test_dag',
+ 'value': 'test_binary',
+ }
diff --git a/tests/api_connexion/test_error_handling.py b/tests/api_connexion/test_error_handling.py
index e921aea88f41d..cfd33dadf5c5c 100644
--- a/tests/api_connexion/test_error_handling.py
+++ b/tests/api_connexion/test_error_handling.py
@@ -37,7 +37,7 @@ def test_incorrect_endpoint_should_return_json(self):
# Then we have parsable JSON as output
- self.assertEqual(404, resp_json["status"])
+ assert 404 == resp_json["status"]
# When we are hitting non-api incorrect enpoint
@@ -45,8 +45,8 @@ def test_incorrect_endpoint_should_return_json(self):
# Then we do not have JSON as response, rather standard HTML
- self.assertIsNone(resp_json)
+ assert resp_json is None
resp_json = self.client.put("/api/v1/variables").json
- self.assertEqual('Method Not Allowed', resp_json["title"])
+ assert 'Method Not Allowed' == resp_json["title"]
diff --git a/tests/api_connexion/test_parameters.py b/tests/api_connexion/test_parameters.py
index 1f625dc71feea..6b9e59c33e514 100644
--- a/tests/api_connexion/test_parameters.py
+++ b/tests/api_connexion/test_parameters.py
@@ -18,6 +18,7 @@
import unittest
from unittest import mock
+import pytest
from pendulum import DateTime
from pendulum.tz.timezone import Timezone
@@ -40,7 +41,7 @@ def setUp(self) -> None:
self.timezoned = datetime.now(tz=timezone.utc)
def test_gives_400_for_naive(self):
- with self.assertRaises(BadRequest):
+ with pytest.raises(BadRequest):
validate_istimezone(self.naive)
def test_timezone_passes(self):
@@ -66,7 +67,7 @@ def test_works_with_datestring_ending_with_zed(self):
def test_raises_400_for_invalid_arg(self):
invalid_datetime = '2020-06-13T22:44:00P'
- with self.assertRaises(BadRequest):
+ with pytest.raises(BadRequest):
format_datetime(invalid_datetime)
@@ -74,26 +75,26 @@ class TestMaximumPagelimit(unittest.TestCase):
@conf_vars({("api", "maximum_page_limit"): "320"})
def test_maximum_limit_return_val(self):
limit = check_limit(300)
- self.assertEqual(limit, 300)
+ assert limit == 300
@conf_vars({("api", "maximum_page_limit"): "320"})
def test_maximum_limit_returns_configured_if_limit_above_conf(self):
limit = check_limit(350)
- self.assertEqual(limit, 320)
+ assert limit == 320
@conf_vars({("api", "maximum_page_limit"): "1000"})
def test_limit_returns_set_max_if_give_limit_is_exceeded(self):
limit = check_limit(1500)
- self.assertEqual(limit, 1000)
+ assert limit == 1000
@conf_vars({("api", "fallback_page_limit"): "100"})
def test_limit_of_zero_returns_default(self):
limit = check_limit(0)
- self.assertEqual(limit, 100)
+ assert limit == 100
@conf_vars({("api", "maximum_page_limit"): "1500"})
def test_negative_limit_raises(self):
- with self.assertRaises(BadRequest):
+ with pytest.raises(BadRequest):
check_limit(-1)
@@ -111,7 +112,7 @@ def test_should_propagate_exceptions(self):
decorator = format_parameters({"param_a": format_datetime})
endpoint = mock.MagicMock()
decorated_endpoint = decorator(endpoint)
- with self.assertRaises(BadRequest):
+ with pytest.raises(BadRequest):
decorated_endpoint(param_a='XXXXX')
@conf_vars({("api", "maximum_page_limit"): "100"})
diff --git a/tests/cli/commands/test_celery_command.py b/tests/cli/commands/test_celery_command.py
index 2a37e270f83d8..41893b7b1700e 100644
--- a/tests/cli/commands/test_celery_command.py
+++ b/tests/cli/commands/test_celery_command.py
@@ -38,9 +38,9 @@ def test_error(self, mock_validate_session):
by mocking validate_session method
"""
mock_validate_session.return_value = False
- with self.assertRaises(SystemExit) as cm:
+ with pytest.raises(SystemExit) as ctx:
celery_command.worker(Namespace(queues=1, concurrency=1))
- self.assertEqual(str(cm.exception), "Worker exiting, database connection precheck failed.")
+ assert str(ctx.value) == "Worker exiting, database connection precheck failed."
@conf_vars({('celery', 'worker_precheck'): 'False'})
def test_worker_precheck_exception(self):
@@ -48,7 +48,7 @@ def test_worker_precheck_exception(self):
Test to check the behaviour of validate_session method
when worker_precheck is absent in airflow configuration
"""
- self.assertTrue(airflow.settings.validate_session())
+ assert airflow.settings.validate_session()
@mock.patch('sqlalchemy.orm.session.Session.execute')
@conf_vars({('celery', 'worker_precheck'): 'True'})
@@ -57,7 +57,7 @@ def test_validate_session_dbapi_exception(self, mock_session):
Test to validate connection failure scenario on SELECT 1 query
"""
mock_session.side_effect = sqlalchemy.exc.OperationalError("m1", "m2", "m3", "m4")
- self.assertEqual(airflow.settings.validate_session(), False)
+ assert airflow.settings.validate_session() is False
@pytest.mark.integration("redis")
@@ -105,7 +105,7 @@ def test_if_right_pid_is_read(self, mock_process, mock_setup_locations):
pid = "123"
# Calling stop_worker should delete the temporary pid file
- with self.assertRaises(FileNotFoundError):
+ with pytest.raises(FileNotFoundError):
with NamedTemporaryFile("w+") as f:
# Create pid file
f.write(pid)
diff --git a/tests/cli/commands/test_cheat_sheet_command.py b/tests/cli/commands/test_cheat_sheet_command.py
index 3928106d991d8..5edc4b7f7b85f 100644
--- a/tests/cli/commands/test_cheat_sheet_command.py
+++ b/tests/cli/commands/test_cheat_sheet_command.py
@@ -100,6 +100,6 @@ def test_should_display_index(self):
args = self.parser.parse_args(['cheat-sheet'])
args.func(args)
output = temp_stdout.getvalue()
- self.assertIn(ALL_COMMANDS, output)
- self.assertIn(SECTION_A, output)
- self.assertIn(SECTION_E, output)
+ assert ALL_COMMANDS in output
+ assert SECTION_A in output
+ assert SECTION_E in output
diff --git a/tests/cli/commands/test_config_command.py b/tests/cli/commands/test_config_command.py
index 532a95f181281..b665db7a4092e 100644
--- a/tests/cli/commands/test_config_command.py
+++ b/tests/cli/commands/test_config_command.py
@@ -19,6 +19,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.cli import cli_parser
from airflow.cli.commands import config_command
from tests.test_utils.config import conf_vars
@@ -39,8 +41,8 @@ def test_cli_show_config_should_write_data(self, mock_conf, mock_stringio):
def test_cli_show_config_should_display_key(self):
with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
config_command.show_config(self.parser.parse_args(['config', 'list', '--color', 'off']))
- self.assertIn('[core]', temp_stdout.getvalue())
- self.assertIn('testkey = test_value', temp_stdout.getvalue())
+ assert '[core]' in temp_stdout.getvalue()
+ assert 'testkey = test_value' in temp_stdout.getvalue()
class TestCliConfigGetValue(unittest.TestCase):
@@ -53,28 +55,26 @@ def test_should_display_value(self):
with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
config_command.get_value(self.parser.parse_args(['config', 'get-value', 'core', 'test_key']))
- self.assertEqual("test_value", temp_stdout.getvalue().strip())
+ assert "test_value" == temp_stdout.getvalue().strip()
@mock.patch("airflow.cli.commands.config_command.conf")
def test_should_raise_exception_when_section_is_missing(self, mock_conf):
mock_conf.has_section.return_value = False
mock_conf.has_option.return_value = True
- with self.assertRaises(SystemExit) as err:
+ with pytest.raises(SystemExit) as ctx:
config_command.get_value(
self.parser.parse_args(['config', 'get-value', 'missing-section', 'dags_folder'])
)
- self.assertEqual("The section [missing-section] is not found in config.", str(err.exception))
+ assert "The section [missing-section] is not found in config." == str(ctx.value)
@mock.patch("airflow.cli.commands.config_command.conf")
def test_should_raise_exception_when_option_is_missing(self, mock_conf):
mock_conf.has_section.return_value = True
mock_conf.has_option.return_value = False
- with self.assertRaises(SystemExit) as err:
+ with pytest.raises(SystemExit) as ctx:
config_command.get_value(
self.parser.parse_args(['config', 'get-value', 'missing-section', 'dags_folder'])
)
- self.assertEqual(
- "The option [missing-section/dags_folder] is not found in config.", str(err.exception)
- )
+ assert "The option [missing-section/dags_folder] is not found in config." == str(ctx.value)
diff --git a/tests/cli/commands/test_connection_command.py b/tests/cli/commands/test_connection_command.py
index 1da1174a995c1..ae78892ff755f 100644
--- a/tests/cli/commands/test_connection_command.py
+++ b/tests/cli/commands/test_connection_command.py
@@ -22,6 +22,7 @@
from contextlib import redirect_stdout
from unittest import mock
+import pytest
from parameterized import parameterized
from airflow.cli import cli_parser
@@ -46,10 +47,10 @@ def test_cli_connection_get(self):
self.parser.parse_args(["connections", "get", "google_cloud_default", "--output", "json"])
)
stdout = stdout.getvalue()
- self.assertIn("google-cloud-platform:///default", stdout)
+ assert "google-cloud-platform:///default" in stdout
def test_cli_connection_get_invalid(self):
- with self.assertRaisesRegex(SystemExit, re.escape("Connection not found.")):
+ with pytest.raises(SystemExit, match=re.escape("Connection not found.")):
connection_command.connections_get(self.parser.parse_args(["connections", "get", "INVALID"]))
@@ -120,8 +121,8 @@ def test_cli_connections_list_as_json(self):
stdout = stdout.getvalue()
for conn_id, conn_type in self.EXPECTED_CONS:
- self.assertIn(conn_type, stdout)
- self.assertIn(conn_id, stdout)
+ assert conn_type in stdout
+ assert conn_id in stdout
def test_cli_connections_filter_conn_id(self):
args = self.parser.parse_args(
@@ -132,7 +133,7 @@ def test_cli_connections_filter_conn_id(self):
connection_command.connections_list(args)
stdout = stdout.getvalue()
- self.assertIn("http_default", stdout)
+ assert "http_default" in stdout
class TestCliExportConnections(unittest.TestCase):
@@ -169,7 +170,7 @@ def tearDown(self):
clear_db_connections()
def test_cli_connections_export_should_return_error_for_invalid_command(self):
- with self.assertRaises(SystemExit):
+ with pytest.raises(SystemExit):
self.parser.parse_args(
[
"connections",
@@ -178,7 +179,7 @@ def test_cli_connections_export_should_return_error_for_invalid_command(self):
)
def test_cli_connections_export_should_return_error_for_invalid_format(self):
- with self.assertRaises(SystemExit):
+ with pytest.raises(SystemExit):
self.parser.parse_args(["connections", "export", "--format", "invalid", "/path/to/file"])
@mock.patch('os.path.splitext')
@@ -196,8 +197,8 @@ def test_cli_connections_export_should_return_error_for_invalid_export_format(
output_filepath,
]
)
- with self.assertRaisesRegex(
- SystemExit, r"Unsupported file format. The file must have the extension .yaml, .json, .env"
+ with pytest.raises(
+ SystemExit, match=r"Unsupported file format. The file must have the extension .yaml, .json, .env"
):
connection_command.connections_export(args)
@@ -226,7 +227,7 @@ def my_side_effect():
output_filepath,
]
)
- with self.assertRaisesRegex(Exception, r"dummy exception"):
+ with pytest.raises(Exception, match=r"dummy exception"):
connection_command.connections_export(args)
mock_splittext.assert_not_called()
@@ -256,7 +257,7 @@ def my_side_effect(_):
output_filepath,
]
)
- with self.assertRaisesRegex(Exception, r"dummy exception"):
+ with pytest.raises(Exception, match=r"dummy exception"):
connection_command.connections_export(args)
mock_splittext.assert_called_once()
@@ -396,7 +397,7 @@ def test_cli_connections_export_should_export_as_env(self, mock_file_open, mock_
mock_splittext.assert_called_once()
mock_file_open.assert_called_once_with(output_filepath, 'w', -1, 'UTF-8', None)
mock_file_open.return_value.write.assert_called_once_with(mock.ANY)
- self.assertIn(mock_file_open.return_value.write.call_args_list[0][0][0], expected_connections)
+ assert mock_file_open.return_value.write.call_args_list[0][0][0] in expected_connections
@mock.patch('os.path.splitext')
@mock.patch('builtins.open', new_callable=mock.mock_open())
@@ -425,7 +426,7 @@ def test_cli_connections_export_should_export_as_env_for_uppercase_file_extensio
mock_splittext.assert_called_once()
mock_file_open.assert_called_once_with(output_filepath, 'w', -1, 'UTF-8', None)
mock_file_open.return_value.write.assert_called_once_with(mock.ANY)
- self.assertIn(mock_file_open.return_value.write.call_args_list[0][0][0], expected_connections)
+ assert mock_file_open.return_value.write.call_args_list[0][0][0] in expected_connections
@mock.patch('os.path.splitext')
@mock.patch('builtins.open', new_callable=mock.mock_open())
@@ -631,7 +632,7 @@ def test_cli_connection_add(self, cmd, expected_output, expected_conn):
stdout = stdout.getvalue()
- self.assertIn(expected_output, stdout)
+ assert expected_output in stdout
conn_id = cmd[2]
with create_session() as session:
comparable_attrs = [
@@ -645,7 +646,7 @@ def test_cli_connection_add(self, cmd, expected_output, expected_conn):
"schema",
]
current_conn = session.query(Connection).filter(Connection.conn_id == conn_id).first()
- self.assertEqual(expected_conn, {attr: getattr(current_conn, attr) for attr in comparable_attrs})
+ assert expected_conn == {attr: getattr(current_conn, attr) for attr in comparable_attrs}
def test_cli_connections_add_duplicate(self):
conn_id = "to_be_duplicated"
@@ -653,21 +654,22 @@ def test_cli_connections_add_duplicate(self):
self.parser.parse_args(["connections", "add", conn_id, "--conn-uri=%s" % TEST_URL])
)
# Check for addition attempt
- with self.assertRaisesRegex(SystemExit, rf"A connection with `conn_id`={conn_id} already exists"):
+ with pytest.raises(SystemExit, match=rf"A connection with `conn_id`={conn_id} already exists"):
connection_command.connections_add(
self.parser.parse_args(["connections", "add", conn_id, "--conn-uri=%s" % TEST_URL])
)
def test_cli_connections_add_delete_with_missing_parameters(self):
# Attempt to add without providing conn_uri
- with self.assertRaisesRegex(
- SystemExit, r"The following args are required to add a connection: \['conn-uri or conn-type'\]"
+ with pytest.raises(
+ SystemExit,
+ match=r"The following args are required to add a connection: \['conn-uri or conn-type'\]",
):
connection_command.connections_add(self.parser.parse_args(["connections", "add", "new1"]))
def test_cli_connections_add_invalid_uri(self):
# Attempt to add with invalid uri
- with self.assertRaisesRegex(SystemExit, r"The URI provided to --conn-uri is invalid: nonsense_uri"):
+ with pytest.raises(SystemExit, match=r"The URI provided to --conn-uri is invalid: nonsense_uri"):
connection_command.connections_add(
self.parser.parse_args(["connections", "add", "new1", "--conn-uri=%s" % "nonsense_uri"])
)
@@ -703,14 +705,14 @@ def test_cli_delete_connections(self, session=None):
stdout = stdout.getvalue()
# Check deletion stdout
- self.assertIn("Successfully deleted connection with `conn_id`=new1", stdout)
+ assert "Successfully deleted connection with `conn_id`=new1" in stdout
# Check deletions
result = session.query(Connection).filter(Connection.conn_id == "new1").first()
- self.assertTrue(result is None)
+ assert result is None
def test_cli_delete_invalid_connection(self):
# Attempt to delete a non-existing connection
- with self.assertRaisesRegex(SystemExit, r"Did not find a connection with `conn_id`=fake"):
+ with pytest.raises(SystemExit, match=r"Did not find a connection with `conn_id`=fake"):
connection_command.connections_delete(self.parser.parse_args(["connections", "delete", "fake"]))
diff --git a/tests/cli/commands/test_dag_command.py b/tests/cli/commands/test_dag_command.py
index 173a7bf54af0d..ed696b083acac 100644
--- a/tests/cli/commands/test_dag_command.py
+++ b/tests/cli/commands/test_dag_command.py
@@ -23,6 +23,8 @@
from datetime import datetime, timedelta
from unittest import mock
+import pytest
+
from airflow import settings
from airflow.cli import cli_parser
from airflow.cli.commands import dag_command
@@ -104,8 +106,8 @@ def test_backfill(self, mock_run):
)
output = stdout.getvalue()
- self.assertIn(f"Dry run of DAG example_bash_operator on {DEFAULT_DATE.isoformat()}\n", output)
- self.assertIn("Task runme_0\n", output)
+ assert f"Dry run of DAG example_bash_operator on {DEFAULT_DATE.isoformat()}\n" in output
+ assert "Task runme_0\n" in output
mock_run.assert_not_called() # Dry run shouldn't run the backfill
@@ -160,9 +162,9 @@ def test_show_dag_print(self):
with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
dag_command.dag_show(self.parser.parse_args(['dags', 'show', 'example_bash_operator']))
out = temp_stdout.getvalue()
- self.assertIn("label=example_bash_operator", out)
- self.assertIn("graph [label=example_bash_operator labelloc=t rankdir=LR]", out)
- self.assertIn("runme_2 -> run_after_loop", out)
+ assert "label=example_bash_operator" in out
+ assert "graph [label=example_bash_operator labelloc=t rankdir=LR]" in out
+ assert "runme_2 -> run_after_loop" in out
@mock.patch("airflow.cli.commands.dag_command.render_dag")
def test_show_dag_dave(self, mock_render_dag):
@@ -174,7 +176,7 @@ def test_show_dag_dave(self, mock_render_dag):
mock_render_dag.return_value.render.assert_called_once_with(
cleanup=True, filename='awesome', format='png'
)
- self.assertIn("File awesome.png saved", out)
+ assert "File awesome.png saved" in out
@mock.patch("airflow.cli.commands.dag_command.subprocess.Popen")
@mock.patch("airflow.cli.commands.dag_command.render_dag")
@@ -188,8 +190,8 @@ def test_show_dag_imgcat(self, mock_render_dag, mock_popen):
out = temp_stdout.getvalue()
mock_render_dag.return_value.pipe.assert_called_once_with(format='png')
mock_popen.return_value.communicate.assert_called_once_with(b'DOT_DATA')
- self.assertIn("OUT", out)
- self.assertIn("ERR", out)
+ assert "OUT" in out
+ assert "ERR" in out
@mock.patch("airflow.cli.commands.dag_command.DAG.run")
def test_cli_backfill_depends_on_past(self, mock_run):
@@ -289,7 +291,7 @@ def test_next_execution(self):
out = temp_stdout.getvalue()
# `next_execution` function is inapplicable if no execution record found
# It prints `None` in such cases
- self.assertIn("None", out)
+ assert "None" in out
# The details below is determined by the schedule_interval of example DAGs
now = DEFAULT_DATE
@@ -313,14 +315,14 @@ def test_next_execution(self):
with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
dag_command.dag_next_execution(args)
out = temp_stdout.getvalue()
- self.assertIn(expected_output[i], out)
+ assert expected_output[i] in out
# Test num-executions = 2
args = self.parser.parse_args(['dags', 'next-execution', dag_id, '--num-executions', '2'])
with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
dag_command.dag_next_execution(args)
out = temp_stdout.getvalue()
- self.assertIn(expected_output_2[i], out)
+ assert expected_output_2[i] in out
# Clean up before leaving
with create_session() as session:
@@ -334,8 +336,8 @@ def test_cli_report(self):
dag_command.dag_report(args)
out = temp_stdout.getvalue()
- self.assertIn("airflow/example_dags/example_complex.py", out)
- self.assertIn("example_complex", out)
+ assert "airflow/example_dags/example_complex.py" in out
+ assert "example_complex" in out
@conf_vars({('core', 'load_examples'): 'true'})
def test_cli_list_dags(self):
@@ -343,11 +345,11 @@ def test_cli_list_dags(self):
with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
dag_command.dag_list_dags(args)
out = temp_stdout.getvalue()
- self.assertIn("owner", out)
- self.assertIn("airflow", out)
- self.assertIn("paused", out)
- self.assertIn("airflow/example_dags/example_complex.py", out)
- self.assertIn("False", out)
+ assert "owner" in out
+ assert "airflow" in out
+ assert "paused" in out
+ assert "airflow/example_dags/example_complex.py" in out
+ assert "False" in out
def test_cli_list_dag_runs(self):
dag_command.dag_trigger(
@@ -394,31 +396,30 @@ def test_cli_list_jobs_with_args(self):
def test_pause(self):
args = self.parser.parse_args(['dags', 'pause', 'example_bash_operator'])
dag_command.dag_pause(args)
- self.assertIn(self.dagbag.dags['example_bash_operator'].get_is_paused(), [True, 1])
+ assert self.dagbag.dags['example_bash_operator'].get_is_paused() in [True, 1]
args = self.parser.parse_args(['dags', 'unpause', 'example_bash_operator'])
dag_command.dag_unpause(args)
- self.assertIn(self.dagbag.dags['example_bash_operator'].get_is_paused(), [False, 0])
+ assert self.dagbag.dags['example_bash_operator'].get_is_paused() in [False, 0]
def test_trigger_dag(self):
dag_command.dag_trigger(
self.parser.parse_args(['dags', 'trigger', 'example_bash_operator', '--conf', '{"foo": "bar"}'])
)
- self.assertRaises(
- ValueError,
- dag_command.dag_trigger,
- self.parser.parse_args(
- [
- 'dags',
- 'trigger',
- 'example_bash_operator',
- '--run-id',
- 'trigger_dag_xxx',
- '--conf',
- 'NOT JSON',
- ]
- ),
- )
+ with pytest.raises(ValueError):
+ dag_command.dag_trigger(
+ self.parser.parse_args(
+ [
+ 'dags',
+ 'trigger',
+ 'example_bash_operator',
+ '--run-id',
+ 'trigger_dag_xxx',
+ '--conf',
+ 'NOT JSON',
+ ]
+ ),
+ )
def test_delete_dag(self):
DM = DagModel
@@ -427,12 +428,11 @@ def test_delete_dag(self):
session.add(DM(dag_id=key))
session.commit()
dag_command.dag_delete(self.parser.parse_args(['dags', 'delete', key, '--yes']))
- self.assertEqual(session.query(DM).filter_by(dag_id=key).count(), 0)
- self.assertRaises(
- AirflowException,
- dag_command.dag_delete,
- self.parser.parse_args(['dags', 'delete', 'does_not_exist_dag', '--yes']),
- )
+ assert session.query(DM).filter_by(dag_id=key).count() == 0
+ with pytest.raises(AirflowException):
+ dag_command.dag_delete(
+ self.parser.parse_args(['dags', 'delete', 'does_not_exist_dag', '--yes']),
+ )
def test_delete_dag_existing_file(self):
# Test to check that the DAG should be deleted even if
@@ -444,18 +444,18 @@ def test_delete_dag_existing_file(self):
session.add(DM(dag_id=key, fileloc=f.name))
session.commit()
dag_command.dag_delete(self.parser.parse_args(['dags', 'delete', key, '--yes']))
- self.assertEqual(session.query(DM).filter_by(dag_id=key).count(), 0)
+ assert session.query(DM).filter_by(dag_id=key).count() == 0
def test_cli_list_jobs(self):
args = self.parser.parse_args(['dags', 'list-jobs'])
dag_command.dag_list_jobs(args)
def test_dag_state(self):
- self.assertEqual(
- None,
+ assert (
dag_command.dag_state(
self.parser.parse_args(['dags', 'state', 'example_bash_operator', DEFAULT_DATE.isoformat()])
- ),
+ )
+ is None
)
@mock.patch("airflow.cli.commands.dag_command.DebugExecutor")
@@ -510,4 +510,4 @@ def test_dag_test_show_dag(self, mock_get_dag, mock_executor, mock_render_dag):
]
)
mock_render_dag.assert_has_calls([mock.call(mock_get_dag.return_value, tis=[])])
- self.assertIn("SOURCE", output)
+ assert "SOURCE" in output
diff --git a/tests/cli/commands/test_db_command.py b/tests/cli/commands/test_db_command.py
index 9c6ad553a188e..4a53b986e95b1 100644
--- a/tests/cli/commands/test_db_command.py
+++ b/tests/cli/commands/test_db_command.py
@@ -18,6 +18,7 @@
import unittest
from unittest import mock
+import pytest
from sqlalchemy.engine.url import make_url
from airflow.cli import cli_parser
@@ -99,16 +100,13 @@ def test_cli_shell_postgres(self, mock_execute_interactive):
_, kwargs = mock_execute_interactive.call_args
env = kwargs['env']
postgres_env = {k: v for k, v in env.items() if k.startswith('PG')}
- self.assertEqual(
- {
- 'PGDATABASE': 'airflow',
- 'PGHOST': 'postgres',
- 'PGPASSWORD': 'airflow',
- 'PGPORT': '5432',
- 'PGUSER': 'postgres',
- },
- postgres_env,
- )
+ assert {
+ 'PGDATABASE': 'airflow',
+ 'PGHOST': 'postgres',
+ 'PGPASSWORD': 'airflow',
+ 'PGPORT': '5432',
+ 'PGUSER': 'postgres',
+ } == postgres_env
@mock.patch("airflow.cli.commands.db_command.execute_interactive")
@mock.patch(
@@ -121,21 +119,18 @@ def test_cli_shell_postgres_without_port(self, mock_execute_interactive):
_, kwargs = mock_execute_interactive.call_args
env = kwargs['env']
postgres_env = {k: v for k, v in env.items() if k.startswith('PG')}
- self.assertEqual(
- {
- 'PGDATABASE': 'airflow',
- 'PGHOST': 'postgres',
- 'PGPASSWORD': 'airflow',
- 'PGPORT': '5432',
- 'PGUSER': 'postgres',
- },
- postgres_env,
- )
+ assert {
+ 'PGDATABASE': 'airflow',
+ 'PGHOST': 'postgres',
+ 'PGPASSWORD': 'airflow',
+ 'PGPORT': '5432',
+ 'PGUSER': 'postgres',
+ } == postgres_env
@mock.patch(
"airflow.cli.commands.db_command.settings.engine.url",
make_url("invalid+psycopg2://postgres:airflow@postgres/airflow"),
)
def test_cli_shell_invalid(self):
- with self.assertRaisesRegex(AirflowException, r"Unknown driver: invalid\+psycopg2"):
+ with pytest.raises(AirflowException, match=r"Unknown driver: invalid\+psycopg2"):
db_command.shell(self.parser.parse_args(['db', 'shell']))
diff --git a/tests/cli/commands/test_info_command.py b/tests/cli/commands/test_info_command.py
index e6c8de47b435b..7fad6e83dcbfa 100644
--- a/tests/cli/commands/test_info_command.py
+++ b/tests/cli/commands/test_info_command.py
@@ -46,7 +46,7 @@ def setUp(self) -> None:
def test_should_remove_pii_from_path(self):
home_path = os.path.expanduser("~/airflow/config")
- self.assertEqual("${HOME}/airflow/config", self.instance.process_path(home_path))
+ assert "${HOME}/airflow/config" == self.instance.process_path(home_path)
@parameterized.expand(
[
@@ -69,29 +69,29 @@ def test_should_remove_pii_from_path(self):
]
)
def test_should_remove_pii_from_url(self, before, after):
- self.assertEqual(after, self.instance.process_url(before))
+ assert after == self.instance.process_url(before)
class TestAirflowInfo(unittest.TestCase):
def test_info(self):
instance = info_command.AirflowInfo(info_command.NullAnonymizer())
text = capture_show_output(instance)
- self.assertIn("Apache Airflow", text)
- self.assertIn(airflow_version, text)
+ assert "Apache Airflow" in text
+ assert airflow_version in text
class TestSystemInfo(unittest.TestCase):
def test_info(self):
instance = info_command.SystemInfo(info_command.NullAnonymizer())
text = capture_show_output(instance)
- self.assertIn("System info", text)
+ assert "System info" in text
class TestPathsInfo(unittest.TestCase):
def test_info(self):
instance = info_command.PathsInfo(info_command.NullAnonymizer())
text = capture_show_output(instance)
- self.assertIn("Paths info", text)
+ assert "Paths info" in text
class TestConfigInfo(unittest.TestCase):
@@ -107,11 +107,11 @@ class TestConfigInfo(unittest.TestCase):
def test_should_read_config(self):
instance = info_command.ConfigInfo(info_command.NullAnonymizer())
text = capture_show_output(instance)
- self.assertIn("TEST_EXECUTOR", text)
- self.assertIn("TEST_DAGS_FOLDER", text)
- self.assertIn("TEST_PLUGINS_FOLDER", text)
- self.assertIn("TEST_LOG_FOLDER", text)
- self.assertIn("postgresql+psycopg2://postgres:airflow@postgres/airflow", text)
+ assert "TEST_EXECUTOR" in text
+ assert "TEST_DAGS_FOLDER" in text
+ assert "TEST_PLUGINS_FOLDER" in text
+ assert "TEST_LOG_FOLDER" in text
+ assert "postgresql+psycopg2://postgres:airflow@postgres/airflow" in text
class TestConfigInfoLogging(unittest.TestCase):
@@ -126,7 +126,7 @@ def test_should_read_logging_configuration(self):
configure_logging()
instance = info_command.ConfigInfo(info_command.NullAnonymizer())
text = capture_show_output(instance)
- self.assertIn("stackdriver", text)
+ assert "stackdriver" in text
def tearDown(self) -> None:
importlib.reload(airflow_local_settings)
@@ -148,8 +148,8 @@ def test_show_info(self):
info_command.show_info(self.parser.parse_args(["info"]))
output = stdout.getvalue()
- self.assertIn(f"Apache Airflow: {airflow_version}", output)
- self.assertIn("postgresql+psycopg2://postgres:airflow@postgres/airflow", output)
+ assert f"Apache Airflow: {airflow_version}" in output
+ assert "postgresql+psycopg2://postgres:airflow@postgres/airflow" in output
@conf_vars(
{
@@ -161,8 +161,8 @@ def test_show_info_anonymize(self):
info_command.show_info(self.parser.parse_args(["info", "--anonymize"]))
output = stdout.getvalue()
- self.assertIn(f"Apache Airflow: {airflow_version}", output)
- self.assertIn("postgresql+psycopg2://p...s:PASSWORD@postgres/airflow", output)
+ assert f"Apache Airflow: {airflow_version}" in output
+ assert "postgresql+psycopg2://p...s:PASSWORD@postgres/airflow" in output
@conf_vars(
{
@@ -185,6 +185,6 @@ def test_show_info_anonymize_fileio(self, mock_requests):
with contextlib.redirect_stdout(io.StringIO()) as stdout:
info_command.show_info(self.parser.parse_args(["info", "--file-io"]))
- self.assertIn("https://file.io/TEST", stdout.getvalue())
+ assert "https://file.io/TEST" in stdout.getvalue()
content = mock_requests.post.call_args[1]["data"]["text"]
- self.assertIn("postgresql+psycopg2://p...s:PASSWORD@postgres/airflow", content)
+ assert "postgresql+psycopg2://p...s:PASSWORD@postgres/airflow" in content
diff --git a/tests/cli/commands/test_kubernetes_command.py b/tests/cli/commands/test_kubernetes_command.py
index 1a6773e51066e..8ae2eef052f79 100644
--- a/tests/cli/commands/test_kubernetes_command.py
+++ b/tests/cli/commands/test_kubernetes_command.py
@@ -47,11 +47,11 @@ def test_generate_dag_yaml(self):
]
)
)
- self.assertEqual(len(os.listdir(directory)), 1)
+ assert len(os.listdir(directory)) == 1
out_dir = directory + "/airflow_yaml_output/"
- self.assertEqual(len(os.listdir(out_dir)), 6)
- self.assertTrue(os.path.isfile(out_dir + file_name))
- self.assertGreater(os.stat(out_dir + file_name).st_size, 0)
+ assert len(os.listdir(out_dir)) == 6
+ assert os.path.isfile(out_dir + file_name)
+ assert os.stat(out_dir + file_name).st_size > 0
class TestCleanUpPodsCommand(unittest.TestCase):
diff --git a/tests/cli/commands/test_legacy_commands.py b/tests/cli/commands/test_legacy_commands.py
index 444cda07c1f08..c8d054542fe00 100644
--- a/tests/cli/commands/test_legacy_commands.py
+++ b/tests/cli/commands/test_legacy_commands.py
@@ -20,6 +20,8 @@
from argparse import ArgumentError
from unittest.mock import MagicMock
+import pytest
+
from airflow.cli import cli_parser
from airflow.cli.commands import config_command
from airflow.cli.commands.legacy_commands import COMMAND_MAP, check_legacy_command
@@ -61,27 +63,24 @@ def setUpClass(cls):
cls.parser = cli_parser.get_parser()
def test_should_display_value(self):
- with self.assertRaises(SystemExit) as cm_exception, contextlib.redirect_stderr(
- io.StringIO()
- ) as temp_stderr:
+ with pytest.raises(SystemExit) as ctx, contextlib.redirect_stderr(io.StringIO()) as temp_stderr:
config_command.get_value(self.parser.parse_args(['worker']))
- self.assertEqual(2, cm_exception.exception.code)
- self.assertIn(
+ assert 2 == ctx.value.code
+ assert (
"`airflow worker` command, has been removed, "
- "please use `airflow celery worker`, see help above.",
- temp_stderr.getvalue().strip(),
+ "please use `airflow celery worker`, see help above." in temp_stderr.getvalue().strip()
)
def test_command_map(self):
for item in LEGACY_COMMANDS:
- self.assertIsNotNone(COMMAND_MAP[item])
+ assert COMMAND_MAP[item] is not None
def test_check_legacy_command(self):
action = MagicMock()
- with self.assertRaises(ArgumentError) as e:
+ with pytest.raises(ArgumentError) as ctx:
check_legacy_command(action, 'list_users')
- self.assertEqual(
- str(e.exception),
- "argument : `airflow list_users` command, has been removed, please use `airflow users list`",
+ assert (
+ str(ctx.value)
+ == "argument : `airflow list_users` command, has been removed, please use `airflow users list`"
)
diff --git a/tests/cli/commands/test_plugins_command.py b/tests/cli/commands/test_plugins_command.py
index bbaaad77b6698..262b59bfac1f5 100644
--- a/tests/cli/commands/test_plugins_command.py
+++ b/tests/cli/commands/test_plugins_command.py
@@ -46,7 +46,7 @@ def test_should_display_no_plugins(self):
with redirect_stdout(io.StringIO()) as temp_stdout:
plugins_command.dump_plugins(self.parser.parse_args(['plugins', '--output=json']))
stdout = temp_stdout.getvalue()
- self.assertIn('No plugins loaded', stdout)
+ assert 'No plugins loaded' in stdout
@mock_plugin_manager(plugins=[TestPlugin])
def test_should_display_one_plugins(self):
diff --git a/tests/cli/commands/test_pool_command.py b/tests/cli/commands/test_pool_command.py
index d40e18786a1cf..92fb46db223c1 100644
--- a/tests/cli/commands/test_pool_command.py
+++ b/tests/cli/commands/test_pool_command.py
@@ -22,6 +22,8 @@
import unittest
from contextlib import redirect_stdout
+import pytest
+
from airflow import models, settings
from airflow.cli import cli_parser
from airflow.cli.commands import pool_command
@@ -59,14 +61,14 @@ def test_pool_list(self):
with redirect_stdout(io.StringIO()) as stdout:
pool_command.pool_list(self.parser.parse_args(['pools', 'list']))
- self.assertIn('foo', stdout.getvalue())
+ assert 'foo' in stdout.getvalue()
def test_pool_list_with_args(self):
pool_command.pool_list(self.parser.parse_args(['pools', 'list', '--output', 'json']))
def test_pool_create(self):
pool_command.pool_set(self.parser.parse_args(['pools', 'set', 'foo', '1', 'test']))
- self.assertEqual(self.session.query(Pool).count(), 2)
+ assert self.session.query(Pool).count() == 2
def test_pool_get(self):
pool_command.pool_set(self.parser.parse_args(['pools', 'set', 'foo', '1', 'test']))
@@ -75,17 +77,17 @@ def test_pool_get(self):
def test_pool_delete(self):
pool_command.pool_set(self.parser.parse_args(['pools', 'set', 'foo', '1', 'test']))
pool_command.pool_delete(self.parser.parse_args(['pools', 'delete', 'foo']))
- self.assertEqual(self.session.query(Pool).count(), 1)
+ assert self.session.query(Pool).count() == 1
def test_pool_import_nonexistent(self):
- with self.assertRaises(SystemExit):
+ with pytest.raises(SystemExit):
pool_command.pool_import(self.parser.parse_args(['pools', 'import', 'nonexistent.json']))
def test_pool_import_invalid_json(self):
with open('pools_import_invalid.json', mode='w') as file:
file.write("not valid json")
- with self.assertRaises(SystemExit):
+ with pytest.raises(SystemExit):
pool_command.pool_import(self.parser.parse_args(['pools', 'import', 'pools_import_invalid.json']))
def test_pool_import_invalid_pools(self):
@@ -93,7 +95,7 @@ def test_pool_import_invalid_pools(self):
with open('pools_import_invalid.json', mode='w') as file:
json.dump(pool_config_input, file)
- with self.assertRaises(SystemExit):
+ with pytest.raises(SystemExit):
pool_command.pool_import(self.parser.parse_args(['pools', 'import', 'pools_import_invalid.json']))
def test_pool_import_export(self):
@@ -114,8 +116,6 @@ def test_pool_import_export(self):
with open('pools_export.json', mode='r') as file:
pool_config_output = json.load(file)
- self.assertEqual(
- pool_config_input, pool_config_output, "Input and output pool files are not same"
- )
+ assert pool_config_input == pool_config_output, "Input and output pool files are not same"
os.remove('pools_import.json')
os.remove('pools_export.json')
diff --git a/tests/cli/commands/test_role_command.py b/tests/cli/commands/test_role_command.py
index 167e3a005435e..3ade36be52a72 100644
--- a/tests/cli/commands/test_role_command.py
+++ b/tests/cli/commands/test_role_command.py
@@ -54,25 +54,25 @@ def clear_roles_and_roles(self):
self.appbuilder.sm.delete_role(role_name)
def test_cli_create_roles(self):
- self.assertIsNone(self.appbuilder.sm.find_role('FakeTeamA'))
- self.assertIsNone(self.appbuilder.sm.find_role('FakeTeamB'))
+ assert self.appbuilder.sm.find_role('FakeTeamA') is None
+ assert self.appbuilder.sm.find_role('FakeTeamB') is None
args = self.parser.parse_args(['roles', 'create', 'FakeTeamA', 'FakeTeamB'])
role_command.roles_create(args)
- self.assertIsNotNone(self.appbuilder.sm.find_role('FakeTeamA'))
- self.assertIsNotNone(self.appbuilder.sm.find_role('FakeTeamB'))
+ assert self.appbuilder.sm.find_role('FakeTeamA') is not None
+ assert self.appbuilder.sm.find_role('FakeTeamB') is not None
def test_cli_create_roles_is_reentrant(self):
- self.assertIsNone(self.appbuilder.sm.find_role('FakeTeamA'))
- self.assertIsNone(self.appbuilder.sm.find_role('FakeTeamB'))
+ assert self.appbuilder.sm.find_role('FakeTeamA') is None
+ assert self.appbuilder.sm.find_role('FakeTeamB') is None
args = self.parser.parse_args(['roles', 'create', 'FakeTeamA', 'FakeTeamB'])
role_command.roles_create(args)
- self.assertIsNotNone(self.appbuilder.sm.find_role('FakeTeamA'))
- self.assertIsNotNone(self.appbuilder.sm.find_role('FakeTeamB'))
+ assert self.appbuilder.sm.find_role('FakeTeamA') is not None
+ assert self.appbuilder.sm.find_role('FakeTeamB') is not None
def test_cli_list_roles(self):
self.appbuilder.sm.add_role('FakeTeamA')
@@ -82,8 +82,8 @@ def test_cli_list_roles(self):
role_command.roles_list(self.parser.parse_args(['roles', 'list']))
stdout = stdout.getvalue()
- self.assertIn('FakeTeamA', stdout)
- self.assertIn('FakeTeamB', stdout)
+ assert 'FakeTeamA' in stdout
+ assert 'FakeTeamB' in stdout
def test_cli_list_roles_with_args(self):
role_command.roles_list(self.parser.parse_args(['roles', 'list', '--output', 'yaml']))
diff --git a/tests/cli/commands/test_sync_perm_command.py b/tests/cli/commands/test_sync_perm_command.py
index 8dd3275f4c685..5f7f86e477cf7 100644
--- a/tests/cli/commands/test_sync_perm_command.py
+++ b/tests/cli/commands/test_sync_perm_command.py
@@ -57,7 +57,7 @@ def test_cli_sync_perm(self, dagbag_mock, mock_cached_app):
dagbag_mock.assert_called_once_with(read_dags_from_db=True)
collect_dags_from_db_mock.assert_called_once_with()
- self.assertEqual(2, len(appbuilder.sm.sync_perm_for_dag.mock_calls))
+ assert 2 == len(appbuilder.sm.sync_perm_for_dag.mock_calls)
appbuilder.sm.sync_perm_for_dag.assert_any_call(
'has_access_control', {'Public': {permissions.ACTION_CAN_READ}}
)
diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py
index afa16d149b111..a011ee6308ed7 100644
--- a/tests/cli/commands/test_task_command.py
+++ b/tests/cli/commands/test_task_command.py
@@ -82,7 +82,7 @@ def test_test(self, mock_run_mini_scheduler):
mock_run_mini_scheduler.assert_not_called()
# Check that prints, and log messages, are shown
- self.assertIn("'example_python_operator__print_the_context__20180101'", stdout.getvalue())
+ assert "'example_python_operator__print_the_context__20180101'" in stdout.getvalue()
@mock.patch("airflow.cli.commands.task_command.LocalTaskJob")
def test_run_naive_taskinstance(self, mock_local_job):
@@ -173,8 +173,8 @@ def test_cli_test_with_env_vars(self):
)
)
output = stdout.getvalue()
- self.assertIn('foo=bar', output)
- self.assertIn('AIRFLOW_TEST_MODE=True', output)
+ assert 'foo=bar' in output
+ assert 'AIRFLOW_TEST_MODE=True' in output
def test_cli_run(self):
task_command.task_run(
@@ -192,8 +192,9 @@ def test_cli_run(self):
],
)
def test_cli_run_invalid_raw_option(self, option: str):
- with self.assertRaisesRegex(
- AirflowException, "Option --raw does not work with some of the other options on this command."
+ with pytest.raises(
+ AirflowException,
+ match="Option --raw does not work with some of the other options on this command.",
):
task_command.task_run(
self.parser.parse_args(
@@ -210,7 +211,7 @@ def test_cli_run_invalid_raw_option(self, option: str):
)
def test_cli_run_mutually_exclusive(self):
- with self.assertRaisesRegex(AirflowException, "Option --raw and --local are mutually exclusive."):
+ with pytest.raises(AirflowException, match="Option --raw and --local are mutually exclusive."):
task_command.task_run(
self.parser.parse_args(
[
@@ -260,18 +261,15 @@ def test_task_states_for_dag_run(self):
)
actual_out = json.loads(stdout.getvalue())
- self.assertEqual(len(actual_out), 1)
- self.assertDictEqual(
- actual_out[0],
- {
- 'dag_id': 'example_python_operator',
- 'execution_date': '2016-01-09T00:00:00+00:00',
- 'task_id': 'print_the_context',
- 'state': 'success',
- 'start_date': ti_start.isoformat(),
- 'end_date': ti_end.isoformat(),
- },
- )
+ assert len(actual_out) == 1
+ assert actual_out[0] == {
+ 'dag_id': 'example_python_operator',
+ 'execution_date': '2016-01-09T00:00:00+00:00',
+ 'task_id': 'print_the_context',
+ 'state': 'success',
+ 'start_date': ti_start.isoformat(),
+ 'end_date': ti_end.isoformat(),
+ }
def test_subdag_clear(self):
args = self.parser.parse_args(['tasks', 'clear', 'example_subdag_operator', '--yes'])
@@ -312,7 +310,7 @@ def test_local_run(self):
ti = TaskInstance(task, args.execution_date)
ti.refresh_from_db()
state = ti.current_state()
- self.assertEqual(state, State.SUCCESS)
+ assert state == State.SUCCESS
class TestLogsfromTaskRunCommand(unittest.TestCase):
@@ -359,12 +357,12 @@ def assert_log_line(self, text, logs_list, expect_from_logging_mixin=False):
[2020-06-24 16:47:23,537] {logging_mixin.py:91} INFO - [2020-06-24 16:47:23,536] {python.py:135}
"""
log_lines = [log for log in logs_list if text in log]
- self.assertEqual(len(log_lines), 1)
+ assert len(log_lines) == 1
log_line = log_lines[0]
if not expect_from_logging_mixin:
# Logs from print statement still show with logging_mixing as filename
# Example: [2020-06-24 17:07:00,482] {logging_mixin.py:91} INFO - Log from Print statement
- self.assertNotIn("logging_mixin.py", log_line)
+ assert "logging_mixin.py" not in log_line
return log_line
@unittest.skipIf(not hasattr(os, 'fork'), "Forking not available")
@@ -381,23 +379,21 @@ def test_logging_with_run_task(self):
print(logs) # In case of a test failures this line would show detailed log
logs_list = logs.splitlines()
- self.assertIn("INFO - Started process", logs)
- self.assertIn(f"Subtask {self.task_id}", logs)
- self.assertIn("standard_task_runner.py", logs)
- self.assertIn(
+ assert "INFO - Started process" in logs
+ assert f"Subtask {self.task_id}" in logs
+ assert "standard_task_runner.py" in logs
+ assert (
f"INFO - Running: ['airflow', 'tasks', 'run', '{self.dag_id}', "
- f"'{self.task_id}', '{self.execution_date_str}',",
- logs,
+ f"'{self.task_id}', '{self.execution_date_str}'," in logs
)
self.assert_log_line("Log from DAG Logger", logs_list)
self.assert_log_line("Log from TI Logger", logs_list)
self.assert_log_line("Log from Print statement", logs_list, expect_from_logging_mixin=True)
- self.assertIn(
+ assert (
f"INFO - Marking task as SUCCESS. dag_id={self.dag_id}, "
- f"task_id={self.task_id}, execution_date=20170101T000000",
- logs,
+ f"task_id={self.task_id}, execution_date=20170101T000000" in logs
)
@mock.patch("airflow.task.task_runner.standard_task_runner.CAN_FORK", False)
@@ -413,21 +409,19 @@ def test_logging_with_run_task_subprocess(self):
print(logs) # In case of a test failures this line would show detailed log
logs_list = logs.splitlines()
- self.assertIn(f"Subtask {self.task_id}", logs)
- self.assertIn("base_task_runner.py", logs)
+ assert f"Subtask {self.task_id}" in logs
+ assert "base_task_runner.py" in logs
self.assert_log_line("Log from DAG Logger", logs_list)
self.assert_log_line("Log from TI Logger", logs_list)
self.assert_log_line("Log from Print statement", logs_list, expect_from_logging_mixin=True)
- self.assertIn(
+ assert (
f"INFO - Running: ['airflow', 'tasks', 'run', '{self.dag_id}', "
- f"'{self.task_id}', '{self.execution_date_str}',",
- logs,
+ f"'{self.task_id}', '{self.execution_date_str}'," in logs
)
- self.assertIn(
+ assert (
f"INFO - Marking task as SUCCESS. dag_id={self.dag_id}, "
- f"task_id={self.task_id}, execution_date=20170101T000000",
- logs,
+ f"task_id={self.task_id}, execution_date=20170101T000000" in logs
)
def test_log_file_template_with_run_task(self):
@@ -543,7 +537,7 @@ def test_run_ignores_all_dependencies(self):
ti_dependent0 = TaskInstance(task=dag.get_task(task0_id), execution_date=DEFAULT_DATE)
ti_dependent0.refresh_from_db()
- self.assertEqual(ti_dependent0.state, State.FAILED)
+ assert ti_dependent0.state == State.FAILED
task1_id = 'test_run_dependency_task'
args1 = [
@@ -560,7 +554,7 @@ def test_run_ignores_all_dependencies(self):
task=dag.get_task(task1_id), execution_date=DEFAULT_DATE + timedelta(days=1)
)
ti_dependency.refresh_from_db()
- self.assertEqual(ti_dependency.state, State.FAILED)
+ assert ti_dependency.state == State.FAILED
task2_id = 'test_run_dependent_task'
args2 = [
@@ -577,4 +571,4 @@ def test_run_ignores_all_dependencies(self):
task=dag.get_task(task2_id), execution_date=DEFAULT_DATE + timedelta(days=1)
)
ti_dependent.refresh_from_db()
- self.assertEqual(ti_dependent.state, State.SUCCESS)
+ assert ti_dependent.state == State.SUCCESS
diff --git a/tests/cli/commands/test_user_command.py b/tests/cli/commands/test_user_command.py
index f5bf1671802c0..5ca3bb08494a6 100644
--- a/tests/cli/commands/test_user_command.py
+++ b/tests/cli/commands/test_user_command.py
@@ -158,7 +158,7 @@ def test_cli_list_users(self):
user_command.users_list(self.parser.parse_args(['users', 'list']))
stdout = stdout.getvalue()
for i in range(0, 3):
- self.assertIn(f'user{i}', stdout)
+ assert f'user{i}' in stdout
def test_cli_list_users_with_args(self):
user_command.users_list(self.parser.parse_args(['users', 'list', '--output', 'json']))
@@ -166,11 +166,11 @@ def test_cli_list_users_with_args(self):
def test_cli_import_users(self):
def assert_user_in_roles(email, roles):
for role in roles:
- self.assertTrue(_does_user_belong_to_role(self.appbuilder, email, role))
+ assert _does_user_belong_to_role(self.appbuilder, email, role)
def assert_user_not_in_roles(email, roles):
for role in roles:
- self.assertFalse(_does_user_belong_to_role(self.appbuilder, email, role))
+ assert not _does_user_belong_to_role(self.appbuilder, email, role)
assert_user_not_in_roles(TEST_USER1_EMAIL, ['Admin', 'Op'])
assert_user_not_in_roles(TEST_USER2_EMAIL, ['Public'])
@@ -250,8 +250,8 @@ def find_by_username(username):
matches[0].pop('id') # this key not required for import
return matches[0]
- self.assertEqual(find_by_username('imported_user1'), user1)
- self.assertEqual(find_by_username('imported_user2'), user2)
+ assert find_by_username('imported_user1') == user1
+ assert find_by_username('imported_user2') == user2
def _import_users_from_file(self, user_list):
json_file_content = json.dumps(user_list)
@@ -291,18 +291,16 @@ def test_cli_add_user_role(self):
)
user_command.users_create(args)
- self.assertFalse(
- _does_user_belong_to_role(appbuilder=self.appbuilder, email=TEST_USER1_EMAIL, rolename='Op'),
- "User should not yet be a member of role 'Op'",
- )
+ assert not _does_user_belong_to_role(
+ appbuilder=self.appbuilder, email=TEST_USER1_EMAIL, rolename='Op'
+ ), "User should not yet be a member of role 'Op'"
args = self.parser.parse_args(['users', 'add-role', '--username', 'test4', '--role', 'Op'])
user_command.users_manage_role(args, remove=False)
- self.assertTrue(
- _does_user_belong_to_role(appbuilder=self.appbuilder, email=TEST_USER1_EMAIL, rolename='Op'),
- "User should have been added to role 'Op'",
- )
+ assert _does_user_belong_to_role(
+ appbuilder=self.appbuilder, email=TEST_USER1_EMAIL, rolename='Op'
+ ), "User should have been added to role 'Op'"
def test_cli_remove_user_role(self):
args = self.parser.parse_args(
@@ -324,15 +322,13 @@ def test_cli_remove_user_role(self):
)
user_command.users_create(args)
- self.assertTrue(
- _does_user_belong_to_role(appbuilder=self.appbuilder, email=TEST_USER1_EMAIL, rolename='Viewer'),
- "User should have been created with role 'Viewer'",
- )
+ assert _does_user_belong_to_role(
+ appbuilder=self.appbuilder, email=TEST_USER1_EMAIL, rolename='Viewer'
+ ), "User should have been created with role 'Viewer'"
args = self.parser.parse_args(['users', 'remove-role', '--username', 'test4', '--role', 'Viewer'])
user_command.users_manage_role(args, remove=True)
- self.assertFalse(
- _does_user_belong_to_role(appbuilder=self.appbuilder, email=TEST_USER1_EMAIL, rolename='Viewer'),
- "User should have been removed from role 'Viewer'",
- )
+ assert not _does_user_belong_to_role(
+ appbuilder=self.appbuilder, email=TEST_USER1_EMAIL, rolename='Viewer'
+ ), "User should have been removed from role 'Viewer'"
diff --git a/tests/cli/commands/test_variable_command.py b/tests/cli/commands/test_variable_command.py
index c9e9318c64d49..8b64ff8a66899 100644
--- a/tests/cli/commands/test_variable_command.py
+++ b/tests/cli/commands/test_variable_command.py
@@ -22,6 +22,8 @@
import unittest.mock
from contextlib import redirect_stdout
+import pytest
+
from airflow import models
from airflow.cli import cli_parser
from airflow.cli.commands import variable_command
@@ -44,25 +46,26 @@ def tearDown(self):
def test_variables_set(self):
"""Test variable_set command"""
variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'foo', 'bar']))
- self.assertIsNotNone(Variable.get("foo"))
- self.assertRaises(KeyError, Variable.get, "foo1")
+ assert Variable.get("foo") is not None
+ with pytest.raises(KeyError):
+ Variable.get("foo1")
def test_variables_get(self):
Variable.set('foo', {'foo': 'bar'}, serialize_json=True)
with redirect_stdout(io.StringIO()) as stdout:
variable_command.variables_get(self.parser.parse_args(['variables', 'get', 'foo']))
- self.assertEqual('{\n "foo": "bar"\n}\n', stdout.getvalue())
+ assert '{\n "foo": "bar"\n}\n' == stdout.getvalue()
def test_get_variable_default_value(self):
with redirect_stdout(io.StringIO()) as stdout:
variable_command.variables_get(
self.parser.parse_args(['variables', 'get', 'baz', '--default', 'bar'])
)
- self.assertEqual("bar\n", stdout.getvalue())
+ assert "bar\n" == stdout.getvalue()
def test_get_variable_missing_variable(self):
- with self.assertRaises(SystemExit):
+ with pytest.raises(SystemExit):
variable_command.variables_get(self.parser.parse_args(['variables', 'get', 'no-existing-VAR']))
def test_variables_set_different_types(self):
@@ -95,14 +98,14 @@ def test_variables_set_different_types(self):
)
# Assert value
- self.assertEqual({'foo': 'oops'}, Variable.get('dict', deserialize_json=True))
- self.assertEqual(['oops'], Variable.get('list', deserialize_json=True))
- self.assertEqual('hello string', Variable.get('str')) # cannot json.loads(str)
- self.assertEqual(42, Variable.get('int', deserialize_json=True))
- self.assertEqual(42.0, Variable.get('float', deserialize_json=True))
- self.assertEqual(True, Variable.get('true', deserialize_json=True))
- self.assertEqual(False, Variable.get('false', deserialize_json=True))
- self.assertEqual(None, Variable.get('null', deserialize_json=True))
+ assert {'foo': 'oops'} == Variable.get('dict', deserialize_json=True)
+ assert ['oops'] == Variable.get('list', deserialize_json=True)
+ assert 'hello string' == Variable.get('str') # cannot json.loads(str)
+ assert 42 == Variable.get('int', deserialize_json=True)
+ assert 42.0 == Variable.get('float', deserialize_json=True)
+ assert Variable.get('true', deserialize_json=True) is True
+ assert Variable.get('false', deserialize_json=True) is False
+ assert Variable.get('null', deserialize_json=True) is None
os.remove('variables_types.json')
@@ -115,11 +118,12 @@ def test_variables_delete(self):
"""Test variable_delete command"""
variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'foo', 'bar']))
variable_command.variables_delete(self.parser.parse_args(['variables', 'delete', 'foo']))
- self.assertRaises(KeyError, Variable.get, "foo")
+ with pytest.raises(KeyError):
+ Variable.get("foo")
def test_variables_import(self):
"""Test variables_import command"""
- with self.assertRaisesRegex(SystemExit, r"Invalid variables file"):
+ with pytest.raises(SystemExit, match=r"Invalid variables file"):
variable_command.variables_import(self.parser.parse_args(['variables', 'import', os.devnull]))
def test_variables_export(self):
@@ -143,14 +147,14 @@ def test_variables_isolation(self):
variable_command.variables_delete(self.parser.parse_args(['variables', 'delete', 'foo']))
variable_command.variables_import(self.parser.parse_args(['variables', 'import', tmp1.name]))
- self.assertEqual('original', Variable.get('bar'))
- self.assertEqual('{\n "foo": "bar"\n}', Variable.get('foo'))
+ assert 'original' == Variable.get('bar')
+ assert '{\n "foo": "bar"\n}' == Variable.get('foo')
# Second export
variable_command.variables_export(self.parser.parse_args(['variables', 'export', tmp2.name]))
second_exp = open(tmp2.name)
- self.assertEqual(first_exp.read(), second_exp.read())
+ assert first_exp.read() == second_exp.read()
# Clean up files
second_exp.close()
diff --git a/tests/cli/commands/test_version_command.py b/tests/cli/commands/test_version_command.py
index ad4f5a830c0c1..e4454aaef55f3 100644
--- a/tests/cli/commands/test_version_command.py
+++ b/tests/cli/commands/test_version_command.py
@@ -32,4 +32,4 @@ def setUpClass(cls):
def test_cli_version(self):
with redirect_stdout(io.StringIO()) as stdout:
airflow.cli.commands.version_command.version(self.parser.parse_args(['version']))
- self.assertIn(version, stdout.getvalue())
+ assert version in stdout.getvalue()
diff --git a/tests/cli/commands/test_webserver_command.py b/tests/cli/commands/test_webserver_command.py
index 1d7a6703df121..a73beadbe2a36 100644
--- a/tests/cli/commands/test_webserver_command.py
+++ b/tests/cli/commands/test_webserver_command.py
@@ -85,7 +85,7 @@ def test_should_start_new_workers_when_refresh_interval_has_passed(self, mock_sl
self.monitor._spawn_new_workers.assert_called_once_with(2) # pylint: disable=no-member
self.monitor._kill_old_workers.assert_not_called() # pylint: disable=no-member
self.monitor._reload_gunicorn.assert_not_called() # pylint: disable=no-member
- self.assertAlmostEqual(self.monitor._last_refresh_time, time.monotonic(), delta=5)
+ assert abs(self.monitor._last_refresh_time - time.monotonic()) < 5
@mock.patch('airflow.cli.commands.webserver_command.sleep')
def test_should_reload_when_plugin_has_been_changed(self, mock_sleep):
@@ -112,7 +112,7 @@ def test_should_reload_when_plugin_has_been_changed(self, mock_sleep):
self.monitor._spawn_new_workers.assert_not_called() # pylint: disable=no-member
self.monitor._kill_old_workers.assert_not_called() # pylint: disable=no-member
self.monitor._reload_gunicorn.assert_called_once_with() # pylint: disable=no-member
- self.assertAlmostEqual(self.monitor._last_refresh_time, time.monotonic(), delta=5)
+ assert abs(self.monitor._last_refresh_time - time.monotonic()) < 5
class TestGunicornMonitorGeneratePluginState(unittest.TestCase):
@@ -144,32 +144,32 @@ def test_should_detect_changes_in_directory(self):
state_a = monitor._generate_plugin_state()
state_b = monitor._generate_plugin_state()
- self.assertEqual(state_a, state_b)
- self.assertEqual(3, len(state_a))
+ assert state_a == state_b
+ assert 3 == len(state_a)
# Should detect new file
self._prepare_test_file(f"{tempdir}/file4.txt", 400)
state_c = monitor._generate_plugin_state()
- self.assertNotEqual(state_b, state_c)
- self.assertEqual(4, len(state_c))
+ assert state_b != state_c
+ assert 4 == len(state_c)
# Should detect changes in files
self._prepare_test_file(f"{tempdir}/file4.txt", 450)
state_d = monitor._generate_plugin_state()
- self.assertNotEqual(state_c, state_d)
- self.assertEqual(4, len(state_d))
+ assert state_c != state_d
+ assert 4 == len(state_d)
# Should support large files
self._prepare_test_file(f"{tempdir}/file4.txt", 4000000)
state_d = monitor._generate_plugin_state()
- self.assertNotEqual(state_c, state_d)
- self.assertEqual(4, len(state_d))
+ assert state_c != state_d
+ assert 4 == len(state_d)
class TestCLIGetNumReadyWorkersRunning(unittest.TestCase):
@@ -195,27 +195,27 @@ def test_ready_prefix_on_cmdline(self):
self.process.children.return_value = [self.child]
with mock.patch('psutil.Process', return_value=self.process):
- self.assertEqual(self.monitor._get_num_ready_workers_running(), 1)
+ assert self.monitor._get_num_ready_workers_running() == 1
def test_ready_prefix_on_cmdline_no_children(self):
self.process.children.return_value = []
with mock.patch('psutil.Process', return_value=self.process):
- self.assertEqual(self.monitor._get_num_ready_workers_running(), 0)
+ assert self.monitor._get_num_ready_workers_running() == 0
def test_ready_prefix_on_cmdline_zombie(self):
self.child.cmdline.return_value = []
self.process.children.return_value = [self.child]
with mock.patch('psutil.Process', return_value=self.process):
- self.assertEqual(self.monitor._get_num_ready_workers_running(), 0)
+ assert self.monitor._get_num_ready_workers_running() == 0
def test_ready_prefix_on_cmdline_dead_process(self):
self.child.cmdline.side_effect = psutil.NoSuchProcess(11347)
self.process.children.return_value = [self.child]
with mock.patch('psutil.Process', return_value=self.process):
- self.assertEqual(self.monitor._get_num_ready_workers_running(), 0)
+ assert self.monitor._get_num_ready_workers_running() == 0
class TestCliWebServer(unittest.TestCase):
@@ -275,7 +275,7 @@ def test_cli_webserver_foreground(self):
):
# Run webserver in foreground and terminate it.
proc = subprocess.Popen(["airflow", "webserver"])
- self.assertEqual(None, proc.poll())
+ assert proc.poll() is None
# Wait for process
time.sleep(10)
@@ -284,7 +284,7 @@ def test_cli_webserver_foreground(self):
proc.terminate()
# -15 - the server was stopped before it started
# 0 - the server terminated correctly
- self.assertIn(proc.wait(60), (-15, 0))
+ assert proc.wait(60) in (-15, 0)
def test_cli_webserver_foreground_with_pid(self):
with tempfile.TemporaryDirectory(prefix='tmp-pid') as tmpdir:
@@ -296,14 +296,14 @@ def test_cli_webserver_foreground_with_pid(self):
AIRFLOW__WEBSERVER__WORKERS="1",
):
proc = subprocess.Popen(["airflow", "webserver", "--pid", pidfile])
- self.assertEqual(None, proc.poll())
+ assert proc.poll() is None
# Check the file specified by --pid option exists
self._wait_pidfile(pidfile)
# Terminate webserver
proc.terminate()
- self.assertEqual(0, proc.wait(60))
+ assert 0 == proc.wait(60)
@pytest.mark.quarantined
def test_cli_webserver_background(self):
@@ -335,21 +335,19 @@ def test_cli_webserver_background(self):
logfile,
]
)
- self.assertEqual(None, proc.poll())
+ assert proc.poll() is None
pid_monitor = self._wait_pidfile(pidfile_monitor)
self._wait_pidfile(pidfile_webserver)
# Assert that gunicorn and its monitor are launched.
- self.assertEqual(
- 0, subprocess.Popen(["pgrep", "-f", "-c", "airflow webserver --daemon"]).wait()
- )
- self.assertEqual(0, subprocess.Popen(["pgrep", "-c", "-f", "gunicorn: master"]).wait())
+ assert 0 == subprocess.Popen(["pgrep", "-f", "-c", "airflow webserver --daemon"]).wait()
+ assert 0 == subprocess.Popen(["pgrep", "-c", "-f", "gunicorn: master"]).wait()
# Terminate monitor process.
proc = psutil.Process(pid_monitor)
proc.terminate()
- self.assertIn(proc.wait(120), (0, None))
+ assert proc.wait(120) in (0, None)
self._check_processes()
except Exception:
@@ -367,20 +365,18 @@ def test_cli_webserver_shutdown_when_gunicorn_master_is_killed(self, _):
# Shorten timeout so that this test doesn't take too long time
args = self.parser.parse_args(['webserver'])
with conf_vars({('webserver', 'web_server_master_timeout'): '10'}):
- with self.assertRaises(SystemExit) as e:
+ with pytest.raises(SystemExit) as ctx:
webserver_command.webserver(args)
- self.assertEqual(e.exception.code, 1)
+ assert ctx.value.code == 1
def test_cli_webserver_debug(self):
env = os.environ.copy()
proc = psutil.Popen(["airflow", "webserver", "--debug"], env=env)
time.sleep(3) # wait for webserver to start
return_code = proc.poll()
- self.assertEqual(
- None, return_code, f"webserver terminated with return code {return_code} in debug mode"
- )
+ assert return_code is None, f"webserver terminated with return code {return_code} in debug mode"
proc.terminate()
- self.assertEqual(-15, proc.wait(60))
+ assert -15 == proc.wait(60)
def test_cli_webserver_access_log_format(self):
@@ -409,7 +405,7 @@ def test_cli_webserver_access_log_format(self):
access_logformat,
]
)
- self.assertEqual(None, proc.poll())
+ assert proc.poll() is None
# Wait for webserver process
time.sleep(10)
@@ -419,9 +415,9 @@ def test_cli_webserver_access_log_format(self):
try:
file = open(access_logfile)
log = json.loads(file.read())
- self.assertEqual('127.0.0.1', log.get('remote_ip'))
- self.assertEqual(len(log), 9)
- self.assertEqual('GET', log.get('request_method'))
+ assert '127.0.0.1' == log.get('remote_ip')
+ assert len(log) == 9
+ assert 'GET' == log.get('request_method')
except OSError:
print("access log file not found at " + access_logfile)
@@ -430,5 +426,5 @@ def test_cli_webserver_access_log_format(self):
proc.terminate()
# -15 - the server was stopped before it started
# 0 - the server terminated correctly
- self.assertIn(proc.wait(60), (-15, 0))
+ assert proc.wait(60) in (-15, 0)
self._check_processes()
diff --git a/tests/cli/test_cli_parser.py b/tests/cli/test_cli_parser.py
index 4a2dd8524420f..1c2e2aaab2e6d 100644
--- a/tests/cli/test_cli_parser.py
+++ b/tests/cli/test_cli_parser.py
@@ -24,6 +24,8 @@
from collections import Counter
from unittest import TestCase
+import pytest
+
from airflow.cli import cli_parser
# Can not be `--snake_case` or contain uppercase letter
@@ -43,7 +45,7 @@ def test_arg_option_long_only(self):
arg for arg in cli_args.values() if len(arg.flags) == 1 and arg.flags[0].startswith("-")
]
for arg in optional_long:
- self.assertIsNone(ILLEGAL_LONG_OPTION_PATTERN.match(arg.flags[0]), f"{arg.flags[0]} is not match")
+ assert ILLEGAL_LONG_OPTION_PATTERN.match(arg.flags[0]) is None, f"{arg.flags[0]} is not match"
def test_arg_option_mix_short_long(self):
"""
@@ -53,10 +55,8 @@ def test_arg_option_mix_short_long(self):
arg for arg in cli_args.values() if len(arg.flags) == 2 and arg.flags[0].startswith("-")
]
for arg in optional_mix:
- self.assertIsNotNone(
- LEGAL_SHORT_OPTION_PATTERN.match(arg.flags[0]), f"{arg.flags[0]} is not match"
- )
- self.assertIsNone(ILLEGAL_LONG_OPTION_PATTERN.match(arg.flags[1]), f"{arg.flags[1]} is not match")
+ assert LEGAL_SHORT_OPTION_PATTERN.match(arg.flags[0]) is not None, f"{arg.flags[0]} is not match"
+ assert ILLEGAL_LONG_OPTION_PATTERN.match(arg.flags[1]) is None, f"{arg.flags[1]} is not match"
def test_subcommand_conflict(self):
"""
@@ -69,9 +69,7 @@ def test_subcommand_conflict(self):
}
for group_name, sub in subcommand.items():
name = [command.name.lower() for command in sub]
- self.assertEqual(
- len(name), len(set(name)), f"Command group {group_name} have conflict subcommand"
- )
+ assert len(name) == len(set(name)), f"Command group {group_name} have conflict subcommand"
def test_subcommand_arg_name_conflict(self):
"""
@@ -85,10 +83,8 @@ def test_subcommand_arg_name_conflict(self):
for group, command in subcommand.items():
for com in command:
conflict_arg = [arg for arg, count in Counter(com.args).items() if count > 1]
- self.assertListEqual(
- [],
- conflict_arg,
- f"Command group {group} function {com.name} have " f"conflict args name {conflict_arg}",
+ assert [] == conflict_arg, (
+ f"Command group {group} function {com.name} have " f"conflict args name {conflict_arg}"
)
def test_subcommand_arg_flag_conflict(self):
@@ -106,31 +102,25 @@ def test_subcommand_arg_flag_conflict(self):
a.flags[0] for a in com.args if (len(a.flags) == 1 and not a.flags[0].startswith("-"))
]
conflict_position = [arg for arg, count in Counter(position).items() if count > 1]
- self.assertListEqual(
- [],
- conflict_position,
+ assert [] == conflict_position, (
f"Command group {group} function {com.name} have conflict "
- f"position flags {conflict_position}",
+ f"position flags {conflict_position}"
)
long_option = [
a.flags[0] for a in com.args if (len(a.flags) == 1 and a.flags[0].startswith("-"))
] + [a.flags[1] for a in com.args if len(a.flags) == 2]
conflict_long_option = [arg for arg, count in Counter(long_option).items() if count > 1]
- self.assertListEqual(
- [],
- conflict_long_option,
+ assert [] == conflict_long_option, (
f"Command group {group} function {com.name} have conflict "
- f"long option flags {conflict_long_option}",
+ f"long option flags {conflict_long_option}"
)
short_option = [a.flags[0] for a in com.args if len(a.flags) == 2]
conflict_short_option = [arg for arg, count in Counter(short_option).items() if count > 1]
- self.assertEqual(
- [],
- conflict_short_option,
+ assert [] == conflict_short_option, (
f"Command group {group} function {com.name} have conflict "
- f"short option flags {conflict_short_option}",
+ f"short option flags {conflict_short_option}"
)
def test_falsy_default_value(self):
@@ -139,20 +129,20 @@ def test_falsy_default_value(self):
arg.add_to_parser(parser)
args = parser.parse_args(['--test', '10'])
- self.assertEqual(args.test, 10)
+ assert args.test == 10
args = parser.parse_args([])
- self.assertEqual(args.test, 0)
+ assert args.test == 0
def test_commands_and_command_group_sections(self):
parser = cli_parser.get_parser()
with contextlib.redirect_stdout(io.StringIO()) as stdout:
- with self.assertRaises(SystemExit):
+ with pytest.raises(SystemExit):
parser.parse_args(['--help'])
stdout = stdout.getvalue()
- self.assertIn("Commands", stdout)
- self.assertIn("Groups", stdout)
+ assert "Commands" in stdout
+ assert "Groups" in stdout
def test_should_display_help(self):
parser = cli_parser.get_parser()
@@ -167,12 +157,12 @@ def test_should_display_help(self):
)
]
for cmd_args in all_command_as_args:
- with self.assertRaises(SystemExit):
+ with pytest.raises(SystemExit):
parser.parse_args([*cmd_args, '--help'])
def test_positive_int(self):
- self.assertEqual(1, cli_parser.positive_int('1'))
+ assert 1 == cli_parser.positive_int('1')
- with self.assertRaises(argparse.ArgumentTypeError):
+ with pytest.raises(argparse.ArgumentTypeError):
cli_parser.positive_int('0')
cli_parser.positive_int('-1')
diff --git a/tests/core/test_config_templates.py b/tests/core/test_config_templates.py
index 42ba99133028a..2efa838741abd 100644
--- a/tests/core/test_config_templates.py
+++ b/tests/core/test_config_templates.py
@@ -83,7 +83,7 @@ class TestAirflowCfg(unittest.TestCase):
def test_should_be_ascii_file(self, filename: str):
with open(os.path.join(CONFIG_TEMPLATES_FOLDER, filename), "rb") as f:
content = f.read().decode("ascii")
- self.assertTrue(content)
+ assert content
@parameterized.expand(
[
@@ -102,4 +102,4 @@ def test_should_be_ini_file(self, filename: str, expected_sections):
config = configparser.ConfigParser()
config.read(filepath)
- self.assertEqual(expected_sections, config.sections())
+ assert expected_sections == config.sections()
diff --git a/tests/core/test_configuration.py b/tests/core/test_configuration.py
index 338c3e807e72b..a2798373afffd 100644
--- a/tests/core/test_configuration.py
+++ b/tests/core/test_configuration.py
@@ -24,6 +24,8 @@
from collections import OrderedDict
from unittest import mock
+import pytest
+
from airflow import configuration
from airflow.configuration import (
DEFAULT_CONFIG,
@@ -54,44 +56,44 @@ def test_airflow_home_default(self):
with unittest.mock.patch.dict('os.environ'):
if 'AIRFLOW_HOME' in os.environ:
del os.environ['AIRFLOW_HOME']
- self.assertEqual(get_airflow_home(), expand_env_var('~/airflow'))
+ assert get_airflow_home() == expand_env_var('~/airflow')
def test_airflow_home_override(self):
with unittest.mock.patch.dict('os.environ', AIRFLOW_HOME='/path/to/airflow'):
- self.assertEqual(get_airflow_home(), '/path/to/airflow')
+ assert get_airflow_home() == '/path/to/airflow'
def test_airflow_config_default(self):
with unittest.mock.patch.dict('os.environ'):
if 'AIRFLOW_CONFIG' in os.environ:
del os.environ['AIRFLOW_CONFIG']
- self.assertEqual(get_airflow_config('/home/airflow'), expand_env_var('/home/airflow/airflow.cfg'))
+ assert get_airflow_config('/home/airflow') == expand_env_var('/home/airflow/airflow.cfg')
def test_airflow_config_override(self):
with unittest.mock.patch.dict('os.environ', AIRFLOW_CONFIG='/path/to/airflow/airflow.cfg'):
- self.assertEqual(get_airflow_config('/home//airflow'), '/path/to/airflow/airflow.cfg')
+ assert get_airflow_config('/home//airflow') == '/path/to/airflow/airflow.cfg'
@conf_vars({("core", "percent"): "with%%inside"})
def test_case_sensitivity(self):
# section and key are case insensitive for get method
# note: this is not the case for as_dict method
- self.assertEqual(conf.get("core", "percent"), "with%inside")
- self.assertEqual(conf.get("core", "PERCENT"), "with%inside")
- self.assertEqual(conf.get("CORE", "PERCENT"), "with%inside")
+ assert conf.get("core", "percent") == "with%inside"
+ assert conf.get("core", "PERCENT") == "with%inside"
+ assert conf.get("CORE", "PERCENT") == "with%inside"
def test_env_var_config(self):
opt = conf.get('testsection', 'testkey')
- self.assertEqual(opt, 'testvalue')
+ assert opt == 'testvalue'
opt = conf.get('testsection', 'testpercent')
- self.assertEqual(opt, 'with%percent')
+ assert opt == 'with%percent'
- self.assertTrue(conf.has_option('testsection', 'testkey'))
+ assert conf.has_option('testsection', 'testkey')
with unittest.mock.patch.dict(
'os.environ', AIRFLOW__KUBERNETES_ENVIRONMENT_VARIABLES__AIRFLOW__TESTSECTION__TESTKEY='nested'
):
opt = conf.get('kubernetes_environment_variables', 'AIRFLOW__TESTSECTION__TESTKEY')
- self.assertEqual(opt, 'nested')
+ assert opt == 'nested'
@mock.patch.dict(
'os.environ', AIRFLOW__KUBERNETES_ENVIRONMENT_VARIABLES__AIRFLOW__TESTSECTION__TESTKEY='nested'
@@ -101,42 +103,40 @@ def test_conf_as_dict(self):
cfg_dict = conf.as_dict()
# test that configs are picked up
- self.assertEqual(cfg_dict['core']['unit_test_mode'], 'True')
+ assert cfg_dict['core']['unit_test_mode'] == 'True'
- self.assertEqual(cfg_dict['core']['percent'], 'with%inside')
+ assert cfg_dict['core']['percent'] == 'with%inside'
# test env vars
- self.assertEqual(cfg_dict['testsection']['testkey'], '< hidden >')
- self.assertEqual(
- cfg_dict['kubernetes_environment_variables']['AIRFLOW__TESTSECTION__TESTKEY'], '< hidden >'
- )
+ assert cfg_dict['testsection']['testkey'] == '< hidden >'
+ assert cfg_dict['kubernetes_environment_variables']['AIRFLOW__TESTSECTION__TESTKEY'] == '< hidden >'
def test_conf_as_dict_source(self):
# test display_source
cfg_dict = conf.as_dict(display_source=True)
- self.assertEqual(cfg_dict['core']['load_examples'][1], 'airflow.cfg')
- self.assertEqual(cfg_dict['core']['load_default_connections'][1], 'airflow.cfg')
- self.assertEqual(cfg_dict['testsection']['testkey'], ('< hidden >', 'env var'))
+ assert cfg_dict['core']['load_examples'][1] == 'airflow.cfg'
+ assert cfg_dict['core']['load_default_connections'][1] == 'airflow.cfg'
+ assert cfg_dict['testsection']['testkey'] == ('< hidden >', 'env var')
def test_conf_as_dict_sensitive(self):
# test display_sensitive
cfg_dict = conf.as_dict(display_sensitive=True)
- self.assertEqual(cfg_dict['testsection']['testkey'], 'testvalue')
- self.assertEqual(cfg_dict['testsection']['testpercent'], 'with%percent')
+ assert cfg_dict['testsection']['testkey'] == 'testvalue'
+ assert cfg_dict['testsection']['testpercent'] == 'with%percent'
# test display_source and display_sensitive
cfg_dict = conf.as_dict(display_sensitive=True, display_source=True)
- self.assertEqual(cfg_dict['testsection']['testkey'], ('testvalue', 'env var'))
+ assert cfg_dict['testsection']['testkey'] == ('testvalue', 'env var')
@conf_vars({("core", "percent"): "with%%inside"})
def test_conf_as_dict_raw(self):
# test display_sensitive
cfg_dict = conf.as_dict(raw=True, display_sensitive=True)
- self.assertEqual(cfg_dict['testsection']['testkey'], 'testvalue')
+ assert cfg_dict['testsection']['testkey'] == 'testvalue'
# Values with '%' in them should be escaped
- self.assertEqual(cfg_dict['testsection']['testpercent'], 'with%%percent')
- self.assertEqual(cfg_dict['core']['percent'], 'with%%inside')
+ assert cfg_dict['testsection']['testpercent'] == 'with%%percent'
+ assert cfg_dict['core']['percent'] == 'with%%inside'
def test_conf_as_dict_exclude_env(self):
# test display_sensitive
@@ -144,7 +144,7 @@ def test_conf_as_dict_exclude_env(self):
# Since testsection is only created from env vars, it shouldn't be
# present at all if we don't ask for env vars to be included.
- self.assertNotIn('testsection', cfg_dict)
+ assert 'testsection' not in cfg_dict
def test_command_precedence(self):
test_config = '''[test]
@@ -167,35 +167,35 @@ def test_command_precedence(self):
('test', 'key2'),
('test', 'key4'),
}
- self.assertEqual('hello', test_conf.get('test', 'key1'))
- self.assertEqual('cmd_result', test_conf.get('test', 'key2'))
- self.assertEqual('airflow', test_conf.get('test', 'key3'))
- self.assertEqual('key4_result', test_conf.get('test', 'key4'))
- self.assertEqual('value6', test_conf.get('another', 'key6'))
-
- self.assertEqual('hello', test_conf.get('test', 'key1', fallback='fb'))
- self.assertEqual('value6', test_conf.get('another', 'key6', fallback='fb'))
- self.assertEqual('fb', test_conf.get('another', 'key7', fallback='fb'))
- self.assertEqual(True, test_conf.getboolean('another', 'key8_boolean', fallback='True'))
- self.assertEqual(10, test_conf.getint('another', 'key8_int', fallback='10'))
- self.assertEqual(1.0, test_conf.getfloat('another', 'key8_float', fallback='1'))
-
- self.assertTrue(test_conf.has_option('test', 'key1'))
- self.assertTrue(test_conf.has_option('test', 'key2'))
- self.assertTrue(test_conf.has_option('test', 'key3'))
- self.assertTrue(test_conf.has_option('test', 'key4'))
- self.assertFalse(test_conf.has_option('test', 'key5'))
- self.assertTrue(test_conf.has_option('another', 'key6'))
+ assert 'hello' == test_conf.get('test', 'key1')
+ assert 'cmd_result' == test_conf.get('test', 'key2')
+ assert 'airflow' == test_conf.get('test', 'key3')
+ assert 'key4_result' == test_conf.get('test', 'key4')
+ assert 'value6' == test_conf.get('another', 'key6')
+
+ assert 'hello' == test_conf.get('test', 'key1', fallback='fb')
+ assert 'value6' == test_conf.get('another', 'key6', fallback='fb')
+ assert 'fb' == test_conf.get('another', 'key7', fallback='fb')
+ assert test_conf.getboolean('another', 'key8_boolean', fallback='True') is True
+ assert 10 == test_conf.getint('another', 'key8_int', fallback='10')
+ assert 1.0 == test_conf.getfloat('another', 'key8_float', fallback='1')
+
+ assert test_conf.has_option('test', 'key1')
+ assert test_conf.has_option('test', 'key2')
+ assert test_conf.has_option('test', 'key3')
+ assert test_conf.has_option('test', 'key4')
+ assert not test_conf.has_option('test', 'key5')
+ assert test_conf.has_option('another', 'key6')
cfg_dict = test_conf.as_dict(display_sensitive=True)
- self.assertEqual('cmd_result', cfg_dict['test']['key2'])
- self.assertNotIn('key2_cmd', cfg_dict['test'])
+ assert 'cmd_result' == cfg_dict['test']['key2']
+ assert 'key2_cmd' not in cfg_dict['test']
# If we exclude _cmds then we should still see the commands to run, not
# their values
cfg_dict = test_conf.as_dict(include_cmds=False, display_sensitive=True)
- self.assertNotIn('key4', cfg_dict['test'])
- self.assertEqual('printf key4_result', cfg_dict['test']['key4_cmd'])
+ assert 'key4' not in cfg_dict['test']
+ assert 'printf key4_result' == cfg_dict['test']['key4_cmd']
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@conf_vars(
@@ -240,9 +240,7 @@ def test_config_from_secret_backend(self, mock_hvac):
('test', 'sql_alchemy_conn'),
}
- self.assertEqual(
- 'sqlite:////Users/airflow/airflow/airflow.db', test_conf.get('test', 'sql_alchemy_conn')
- )
+ assert 'sqlite:////Users/airflow/airflow/airflow.db' == test_conf.get('test', 'sql_alchemy_conn')
def test_getboolean(self):
"""Test AirflowConfigParser.getboolean"""
@@ -264,22 +262,22 @@ def test_getboolean(self):
key8 = true #123
"""
test_conf = AirflowConfigParser(default_config=test_config)
- with self.assertRaisesRegex(
+ with pytest.raises(
AirflowConfigException,
- re.escape(
+ match=re.escape(
'Failed to convert value to bool. Please check "key1" key in "type_validation" section. '
'Current value: "non_bool_value".'
),
):
test_conf.getboolean('type_validation', 'key1')
- self.assertTrue(isinstance(test_conf.getboolean('true', 'key3'), bool))
- self.assertEqual(True, test_conf.getboolean('true', 'key2'))
- self.assertEqual(True, test_conf.getboolean('true', 'key3'))
- self.assertEqual(True, test_conf.getboolean('true', 'key4'))
- self.assertEqual(False, test_conf.getboolean('false', 'key5'))
- self.assertEqual(False, test_conf.getboolean('false', 'key6'))
- self.assertEqual(False, test_conf.getboolean('false', 'key7'))
- self.assertEqual(True, test_conf.getboolean('inline-comment', 'key8'))
+ assert isinstance(test_conf.getboolean('true', 'key3'), bool)
+ assert test_conf.getboolean('true', 'key2') is True
+ assert test_conf.getboolean('true', 'key3') is True
+ assert test_conf.getboolean('true', 'key4') is True
+ assert test_conf.getboolean('false', 'key5') is False
+ assert test_conf.getboolean('false', 'key6') is False
+ assert test_conf.getboolean('false', 'key7') is False
+ assert test_conf.getboolean('inline-comment', 'key8') is True
def test_getint(self):
"""Test AirflowConfigParser.getint"""
@@ -291,16 +289,16 @@ def test_getint(self):
key2 = 1
"""
test_conf = AirflowConfigParser(default_config=test_config)
- with self.assertRaisesRegex(
+ with pytest.raises(
AirflowConfigException,
- re.escape(
+ match=re.escape(
'Failed to convert value to int. Please check "key1" key in "invalid" section. '
'Current value: "str".'
),
):
test_conf.getint('invalid', 'key1')
- self.assertTrue(isinstance(test_conf.getint('valid', 'key2'), int))
- self.assertEqual(1, test_conf.getint('valid', 'key2'))
+ assert isinstance(test_conf.getint('valid', 'key2'), int)
+ assert 1 == test_conf.getint('valid', 'key2')
def test_getfloat(self):
"""Test AirflowConfigParser.getfloat"""
@@ -312,16 +310,16 @@ def test_getfloat(self):
key2 = 1.23
"""
test_conf = AirflowConfigParser(default_config=test_config)
- with self.assertRaisesRegex(
+ with pytest.raises(
AirflowConfigException,
- re.escape(
+ match=re.escape(
'Failed to convert value to float. Please check "key1" key in "invalid" section. '
'Current value: "str".'
),
):
test_conf.getfloat('invalid', 'key1')
- self.assertTrue(isinstance(test_conf.getfloat('valid', 'key2'), float))
- self.assertEqual(1.23, test_conf.getfloat('valid', 'key2'))
+ assert isinstance(test_conf.getfloat('valid', 'key2'), float)
+ assert 1.23 == test_conf.getfloat('valid', 'key2')
def test_has_option(self):
test_config = '''[test]
@@ -329,9 +327,9 @@ def test_has_option(self):
'''
test_conf = AirflowConfigParser()
test_conf.read_string(test_config)
- self.assertTrue(test_conf.has_option('test', 'key1'))
- self.assertFalse(test_conf.has_option('test', 'key_not_exists'))
- self.assertFalse(test_conf.has_option('section_not_exists', 'key1'))
+ assert test_conf.has_option('test', 'key1')
+ assert not test_conf.has_option('test', 'key_not_exists')
+ assert not test_conf.has_option('section_not_exists', 'key1')
def test_remove_option(self):
test_config = '''[test]
@@ -346,12 +344,12 @@ def test_remove_option(self):
test_conf = AirflowConfigParser(default_config=parameterized_config(test_config_default))
test_conf.read_string(test_config)
- self.assertEqual('hello', test_conf.get('test', 'key1'))
+ assert 'hello' == test_conf.get('test', 'key1')
test_conf.remove_option('test', 'key1', remove_default=False)
- self.assertEqual('awesome', test_conf.get('test', 'key1'))
+ assert 'awesome' == test_conf.get('test', 'key1')
test_conf.remove_option('test', 'key2')
- self.assertFalse(test_conf.has_option('test', 'key2'))
+ assert not test_conf.has_option('test', 'key2')
def test_getsection(self):
test_config = '''
@@ -371,21 +369,14 @@ def test_getsection(self):
test_conf = AirflowConfigParser(default_config=parameterized_config(test_config_default))
test_conf.read_string(test_config)
- self.assertEqual(OrderedDict([('key1', 'hello'), ('key2', 'airflow')]), test_conf.getsection('test'))
- self.assertEqual(
- OrderedDict([('key3', 'value3'), ('testkey', 'testvalue'), ('testpercent', 'with%percent')]),
- test_conf.getsection('testsection'),
- )
+ assert OrderedDict([('key1', 'hello'), ('key2', 'airflow')]) == test_conf.getsection('test')
+ assert OrderedDict(
+ [('key3', 'value3'), ('testkey', 'testvalue'), ('testpercent', 'with%percent')]
+ ) == test_conf.getsection('testsection')
- self.assertEqual(
- OrderedDict([('key', 'value')]),
- test_conf.getsection('new_section'),
- )
+ assert OrderedDict([('key', 'value')]) == test_conf.getsection('new_section')
- self.assertEqual(
- None,
- test_conf.getsection('non_existent_section'),
- )
+ assert test_conf.getsection('non_existent_section') is None
def test_get_section_should_respect_cmd_env_variable(self):
with tempfile.NamedTemporaryFile(delete=False) as cmd_file:
@@ -398,7 +389,7 @@ def test_get_section_should_respect_cmd_env_variable(self):
with mock.patch.dict("os.environ", {"AIRFLOW__WEBSERVER__SECRET_KEY_CMD": cmd_file.name}):
content = conf.getsection("webserver")
os.unlink(cmd_file.name)
- self.assertEqual(content["secret_key"], "difficult_unpredictable_cat_password")
+ assert content["secret_key"] == "difficult_unpredictable_cat_password"
def test_kubernetes_environment_variables_section(self):
test_config = '''
@@ -412,17 +403,16 @@ def test_kubernetes_environment_variables_section(self):
test_conf = AirflowConfigParser(default_config=parameterized_config(test_config_default))
test_conf.read_string(test_config)
- self.assertEqual(
- OrderedDict([('key1', 'hello'), ('AIRFLOW_HOME', '/root/airflow')]),
- test_conf.getsection('kubernetes_environment_variables'),
+ assert OrderedDict([('key1', 'hello'), ('AIRFLOW_HOME', '/root/airflow')]) == test_conf.getsection(
+ 'kubernetes_environment_variables'
)
def test_broker_transport_options(self):
section_dict = conf.getsection("celery_broker_transport_options")
- self.assertTrue(isinstance(section_dict['visibility_timeout'], int))
- self.assertTrue(isinstance(section_dict['_test_only_bool'], bool))
- self.assertTrue(isinstance(section_dict['_test_only_float'], float))
- self.assertTrue(isinstance(section_dict['_test_only_string'], str))
+ assert isinstance(section_dict['visibility_timeout'], int)
+ assert isinstance(section_dict['_test_only_bool'], bool)
+ assert isinstance(section_dict['_test_only_float'], float)
+ assert isinstance(section_dict['_test_only_string'], str)
@conf_vars(
{
@@ -440,12 +430,12 @@ def test_deprecated_options(self):
# Remove it so we are sure we use the right setting
conf.remove_option('celery', 'worker_concurrency')
- with self.assertWarns(DeprecationWarning):
+ with pytest.warns(DeprecationWarning):
with mock.patch.dict('os.environ', AIRFLOW__CELERY__CELERYD_CONCURRENCY="99"):
- self.assertEqual(conf.getint('celery', 'worker_concurrency'), 99)
+ assert conf.getint('celery', 'worker_concurrency') == 99
- with self.assertWarns(DeprecationWarning), conf_vars({('celery', 'celeryd_concurrency'): '99'}):
- self.assertEqual(conf.getint('celery', 'worker_concurrency'), 99)
+ with pytest.warns(DeprecationWarning), conf_vars({('celery', 'celeryd_concurrency'): '99'}):
+ assert conf.getint('celery', 'worker_concurrency') == 99
@conf_vars(
{
@@ -464,12 +454,12 @@ def test_deprecated_options_with_new_section(self):
conf.remove_option('core', 'logging_level')
conf.remove_option('logging', 'logging_level')
- with self.assertWarns(DeprecationWarning):
+ with pytest.warns(DeprecationWarning):
with mock.patch.dict('os.environ', AIRFLOW__CORE__LOGGING_LEVEL="VALUE"):
- self.assertEqual(conf.get('logging', 'logging_level'), "VALUE")
+ assert conf.get('logging', 'logging_level') == "VALUE"
- with self.assertWarns(DeprecationWarning), conf_vars({('core', 'logging_level'): 'VALUE'}):
- self.assertEqual(conf.get('logging', 'logging_level'), "VALUE")
+ with pytest.warns(DeprecationWarning), conf_vars({('core', 'logging_level'): 'VALUE'}):
+ assert conf.get('logging', 'logging_level') == "VALUE"
@conf_vars(
{
@@ -486,11 +476,11 @@ def test_deprecated_options_cmd(self):
conf.remove_option('celery', 'result_backend')
with conf_vars({('celery', 'celery_result_backend_cmd'): '/bin/echo 99'}):
- with self.assertWarns(DeprecationWarning):
+ with pytest.warns(DeprecationWarning):
tmp = None
if 'AIRFLOW__CELERY__RESULT_BACKEND' in os.environ:
tmp = os.environ.pop('AIRFLOW__CELERY__RESULT_BACKEND')
- self.assertEqual(conf.getint('celery', 'result_backend'), 99)
+ assert conf.getint('celery', 'result_backend') == 99
if tmp:
os.environ['AIRFLOW__CELERY__RESULT_BACKEND'] = tmp
@@ -516,14 +506,14 @@ def make_config():
test_conf.validate()
return test_conf
- with self.assertWarns(FutureWarning):
+ with pytest.warns(FutureWarning):
test_conf = make_config()
- self.assertEqual(test_conf.get('core', 'hostname_callable'), 'socket.getfqdn')
+ assert test_conf.get('core', 'hostname_callable') == 'socket.getfqdn'
- with self.assertWarns(FutureWarning):
+ with pytest.warns(FutureWarning):
with unittest.mock.patch.dict('os.environ', AIRFLOW__CORE__HOSTNAME_CALLABLE='socket:getfqdn'):
test_conf = make_config()
- self.assertEqual(test_conf.get('core', 'hostname_callable'), 'socket.getfqdn')
+ assert test_conf.get('core', 'hostname_callable') == 'socket.getfqdn'
with reset_warning_registry():
with warnings.catch_warnings(record=True) as warning:
@@ -532,8 +522,8 @@ def make_config():
AIRFLOW__CORE__HOSTNAME_CALLABLE='CarrierPigeon',
):
test_conf = make_config()
- self.assertEqual(test_conf.get('core', 'hostname_callable'), 'CarrierPigeon')
- self.assertListEqual([], warning)
+ assert test_conf.get('core', 'hostname_callable') == 'CarrierPigeon'
+ assert [] == warning
def test_deprecated_funcs(self):
for func in [
@@ -548,7 +538,7 @@ def test_deprecated_funcs(self):
'set',
]:
with mock.patch(f'airflow.configuration.conf.{func}') as mock_method:
- with self.assertWarns(DeprecationWarning):
+ with pytest.warns(DeprecationWarning):
getattr(configuration, func)()
mock_method.assert_called_once()
@@ -564,48 +554,48 @@ def test_command_from_env(self):
# AIRFLOW__TESTCMDENV__ITSACOMMAND_CMD maps to ('testcmdenv', 'itsacommand') in
# sensitive_config_values and therefore should return 'OK' from the environment variable's
# echo command, and must not return 'NOT OK' from the configuration
- self.assertEqual(test_cmdenv_conf.get('testcmdenv', 'itsacommand'), 'OK')
+ assert test_cmdenv_conf.get('testcmdenv', 'itsacommand') == 'OK'
# AIRFLOW__TESTCMDENV__NOTACOMMAND_CMD maps to no entry in sensitive_config_values and therefore
# the option should return 'OK' from the configuration, and must not return 'NOT OK' from
# the environment variable's echo command
- self.assertEqual(test_cmdenv_conf.get('testcmdenv', 'notacommand'), 'OK')
+ assert test_cmdenv_conf.get('testcmdenv', 'notacommand') == 'OK'
def test_parameterized_config_gen(self):
cfg = parameterized_config(DEFAULT_CONFIG)
# making sure some basic building blocks are present:
- self.assertIn("[core]", cfg)
- self.assertIn("dags_folder", cfg)
- self.assertIn("sql_alchemy_conn", cfg)
- self.assertIn("fernet_key", cfg)
+ assert "[core]" in cfg
+ assert "dags_folder" in cfg
+ assert "sql_alchemy_conn" in cfg
+ assert "fernet_key" in cfg
# making sure replacement actually happened
- self.assertNotIn("{AIRFLOW_HOME}", cfg)
- self.assertNotIn("{FERNET_KEY}", cfg)
+ assert "{AIRFLOW_HOME}" not in cfg
+ assert "{FERNET_KEY}" not in cfg
def test_config_use_original_when_original_and_fallback_are_present(self):
- self.assertTrue(conf.has_option("core", "FERNET_KEY"))
- self.assertFalse(conf.has_option("core", "FERNET_KEY_CMD"))
+ assert conf.has_option("core", "FERNET_KEY")
+ assert not conf.has_option("core", "FERNET_KEY_CMD")
fernet_key = conf.get('core', 'FERNET_KEY')
with conf_vars({('core', 'FERNET_KEY_CMD'): 'printf HELLO'}):
fallback_fernet_key = conf.get("core", "FERNET_KEY")
- self.assertEqual(fernet_key, fallback_fernet_key)
+ assert fernet_key == fallback_fernet_key
def test_config_throw_error_when_original_and_fallback_is_absent(self):
- self.assertTrue(conf.has_option("core", "FERNET_KEY"))
- self.assertFalse(conf.has_option("core", "FERNET_KEY_CMD"))
+ assert conf.has_option("core", "FERNET_KEY")
+ assert not conf.has_option("core", "FERNET_KEY_CMD")
with conf_vars({('core', 'fernet_key'): None}):
- with self.assertRaises(AirflowConfigException) as cm:
+ with pytest.raises(AirflowConfigException) as ctx:
conf.get("core", "FERNET_KEY")
- exception = str(cm.exception)
+ exception = str(ctx.value)
message = "section/key [core/fernet_key] not found in config"
- self.assertEqual(message, exception)
+ assert message == exception
def test_config_override_original_when_non_empty_envvar_is_provided(self):
key = "AIRFLOW__CORE__FERNET_KEY"
@@ -614,7 +604,7 @@ def test_config_override_original_when_non_empty_envvar_is_provided(self):
with mock.patch.dict('os.environ', {key: value}):
fernet_key = conf.get('core', 'FERNET_KEY')
- self.assertEqual(value, fernet_key)
+ assert value == fernet_key
def test_config_override_original_when_empty_envvar_is_provided(self):
key = "AIRFLOW__CORE__FERNET_KEY"
@@ -623,40 +613,41 @@ def test_config_override_original_when_empty_envvar_is_provided(self):
with mock.patch.dict('os.environ', {key: value}):
fernet_key = conf.get('core', 'FERNET_KEY')
- self.assertEqual(value, fernet_key)
+ assert value == fernet_key
@mock.patch.dict("os.environ", {"AIRFLOW__CORE__DAGS_FOLDER": "/tmp/test_folder"})
def test_write_should_respect_env_variable(self):
with io.StringIO() as string_file:
conf.write(string_file)
content = string_file.getvalue()
- self.assertIn("dags_folder = /tmp/test_folder", content)
+ assert "dags_folder = /tmp/test_folder" in content
def test_run_command(self):
write = r'sys.stdout.buffer.write("\u1000foo".encode("utf8"))'
cmd = f'import sys; {write}; sys.stdout.flush()'
- self.assertEqual(run_command(f"python -c '{cmd}'"), '\u1000foo')
+ assert run_command(f"python -c '{cmd}'") == '\u1000foo'
- self.assertEqual(run_command('echo "foo bar"'), 'foo bar\n')
- self.assertRaises(AirflowConfigException, run_command, 'bash -c "exit 1"')
+ assert run_command('echo "foo bar"') == 'foo bar\n'
+ with pytest.raises(AirflowConfigException):
+ run_command('bash -c "exit 1"')
def test_confirm_unittest_mod(self):
- self.assertTrue(conf.get('core', 'unit_test_mode'))
+ assert conf.get('core', 'unit_test_mode')
@conf_vars({("core", "store_serialized_dags"): "True"})
def test_store_dag_code_default_config(self):
store_serialized_dags = conf.getboolean('core', 'store_serialized_dags', fallback=False)
store_dag_code = conf.getboolean("core", "store_dag_code", fallback=store_serialized_dags)
- self.assertFalse(conf.has_option("core", "store_dag_code"))
- self.assertTrue(store_serialized_dags)
- self.assertTrue(store_dag_code)
+ assert not conf.has_option("core", "store_dag_code")
+ assert store_serialized_dags
+ assert store_dag_code
@conf_vars({("core", "store_serialized_dags"): "True", ("core", "store_dag_code"): "False"})
def test_store_dag_code_config_when_set(self):
store_serialized_dags = conf.getboolean('core', 'store_serialized_dags', fallback=False)
store_dag_code = conf.getboolean("core", "store_dag_code", fallback=store_serialized_dags)
- self.assertTrue(conf.has_option("core", "store_dag_code"))
- self.assertTrue(store_serialized_dags)
- self.assertFalse(store_dag_code)
+ assert conf.has_option("core", "store_dag_code")
+ assert store_serialized_dags
+ assert not store_dag_code
diff --git a/tests/core/test_core.py b/tests/core/test_core.py
index ba05f0213f4bf..5073aa473d582 100644
--- a/tests/core/test_core.py
+++ b/tests/core/test_core.py
@@ -23,6 +23,8 @@
from datetime import timedelta
from time import sleep
+import pytest
+
from airflow import settings
from airflow.exceptions import AirflowException, AirflowTaskTimeout
from airflow.hooks.base import BaseHook
@@ -124,31 +126,28 @@ def test_illegal_args(self):
"""
msg = 'Invalid arguments were passed to BashOperator (task_id: test_illegal_args).'
with conf_vars({('operators', 'allow_illegal_arguments'): 'True'}):
- with self.assertWarns(PendingDeprecationWarning) as warning:
+ with pytest.warns(PendingDeprecationWarning) as warnings:
BashOperator(
task_id='test_illegal_args',
bash_command='echo success',
dag=self.dag,
illegal_argument_1234='hello?',
)
- assert any(msg in str(w) for w in warning.warnings)
+ assert any(msg in str(w) for w in warnings)
def test_illegal_args_forbidden(self):
"""
Tests that operators raise exceptions on illegal arguments when
illegal arguments are not allowed.
"""
- with self.assertRaises(AirflowException) as ctx:
+ with pytest.raises(AirflowException) as ctx:
BashOperator(
task_id='test_illegal_args',
bash_command='echo success',
dag=self.dag,
illegal_argument_1234='hello?',
)
- self.assertIn(
- 'Invalid arguments were passed to BashOperator (task_id: test_illegal_args).',
- str(ctx.exception),
- )
+ assert 'Invalid arguments were passed to BashOperator (task_id: test_illegal_args).' in str(ctx.value)
def test_bash_operator(self):
op = BashOperator(task_id='test_bash_operator', bash_command="echo success", dag=self.dag)
@@ -176,7 +175,8 @@ def test_bash_operator_kill(self):
bash_command="/bin/bash -c 'sleep %s'" % sleep_time,
dag=self.dag,
)
- self.assertRaises(AirflowTaskTimeout, op.run, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ with pytest.raises(AirflowTaskTimeout):
+ op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
sleep(2)
pid = -1
for proc in psutil.process_iter():
@@ -201,10 +201,9 @@ def check_failure(context, test_case=self):
dag=self.dag,
on_failure_callback=check_failure,
)
- self.assertRaises(
- AirflowException, op.run, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True
- )
- self.assertTrue(data['called'])
+ with pytest.raises(AirflowException):
+ op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+ assert data['called']
def test_dryrun(self):
op = BashOperator(task_id='test_dryrun', bash_command="echo success", dag=self.dag)
@@ -226,9 +225,8 @@ def test_timeout(self):
python_callable=lambda: sleep(5),
dag=self.dag,
)
- self.assertRaises(
- AirflowTaskTimeout, op.run, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True
- )
+ with pytest.raises(AirflowTaskTimeout):
+ op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_python_op(self):
def test_py_op(templates_dict, ds, **kwargs):
@@ -243,7 +241,7 @@ def test_py_op(templates_dict, ds, **kwargs):
def test_complex_template(self):
def verify_templated_field(context):
- self.assertEqual(context['ti'].task.some_templated_field['bar'][1], context['ds'])
+ assert context['ti'].task.some_templated_field['bar'][1] == context['ds']
op = OperatorSubclass(
task_id='test_complex_template',
@@ -282,26 +280,26 @@ def test_task_get_template(self):
context = ti.get_template_context()
# DEFAULT DATE is 2015-01-01
- self.assertEqual(context['ds'], '2015-01-01')
- self.assertEqual(context['ds_nodash'], '20150101')
+ assert context['ds'] == '2015-01-01'
+ assert context['ds_nodash'] == '20150101'
# next_ds is 2015-01-02 as the dag interval is daily
- self.assertEqual(context['next_ds'], '2015-01-02')
- self.assertEqual(context['next_ds_nodash'], '20150102')
+ assert context['next_ds'] == '2015-01-02'
+ assert context['next_ds_nodash'] == '20150102'
# prev_ds is 2014-12-31 as the dag interval is daily
- self.assertEqual(context['prev_ds'], '2014-12-31')
- self.assertEqual(context['prev_ds_nodash'], '20141231')
+ assert context['prev_ds'] == '2014-12-31'
+ assert context['prev_ds_nodash'] == '20141231'
- self.assertEqual(context['ts'], '2015-01-01T00:00:00+00:00')
- self.assertEqual(context['ts_nodash'], '20150101T000000')
- self.assertEqual(context['ts_nodash_with_tz'], '20150101T000000+0000')
+ assert context['ts'] == '2015-01-01T00:00:00+00:00'
+ assert context['ts_nodash'] == '20150101T000000'
+ assert context['ts_nodash_with_tz'] == '20150101T000000+0000'
- self.assertEqual(context['yesterday_ds'], '2014-12-31')
- self.assertEqual(context['yesterday_ds_nodash'], '20141231')
+ assert context['yesterday_ds'] == '2014-12-31'
+ assert context['yesterday_ds_nodash'] == '20141231'
- self.assertEqual(context['tomorrow_ds'], '2015-01-02')
- self.assertEqual(context['tomorrow_ds_nodash'], '20150102')
+ assert context['tomorrow_ds'] == '2015-01-02'
+ assert context['tomorrow_ds_nodash'] == '20150102'
def test_local_task_job(self):
TI = TaskInstance
@@ -319,7 +317,7 @@ def test_raw_job(self):
ti.run(ignore_ti_state=True)
def test_bad_trigger_rule(self):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
DummyOperator(task_id='test_bad_trigger', trigger_rule="non_existent", dag=self.dag)
def test_terminate_task(self):
@@ -341,7 +339,7 @@ def test_terminate_task(self):
session = settings.Session()
ti.refresh_from_db(session=session)
# making sure it's actually running
- self.assertEqual(State.RUNNING, ti.state)
+ assert State.RUNNING == ti.state
ti = (
session.query(TI)
.filter_by(dag_id=task.dag_id, task_id=task.task_id, execution_date=DEFAULT_DATE)
@@ -356,7 +354,7 @@ def test_terminate_task(self):
# making sure that the task ended up as failed
ti.refresh_from_db(session=session)
- self.assertEqual(State.FAILED, ti.state)
+ assert State.FAILED == ti.state
session.close()
def test_task_fail_duration(self):
@@ -390,9 +388,9 @@ def test_task_fail_duration(self):
.all()
)
- self.assertEqual(0, len(op1_fails))
- self.assertEqual(1, len(op2_fails))
- self.assertGreaterEqual(sum([f.duration for f in op2_fails]), 3)
+ assert 0 == len(op1_fails)
+ assert 1 == len(op2_fails)
+ assert sum([f.duration for f in op2_fails]) >= 3
def test_externally_triggered_dagrun(self):
TI = TaskInstance
@@ -418,8 +416,8 @@ def test_externally_triggered_dagrun(self):
context = ti.get_template_context()
# next_ds/prev_ds should be the execution date for manually triggered runs
- self.assertEqual(context['next_ds'], execution_ds)
- self.assertEqual(context['next_ds_nodash'], execution_ds_nodash)
+ assert context['next_ds'] == execution_ds
+ assert context['next_ds_nodash'] == execution_ds_nodash
- self.assertEqual(context['prev_ds'], execution_ds)
- self.assertEqual(context['prev_ds_nodash'], execution_ds_nodash)
+ assert context['prev_ds'] == execution_ds
+ assert context['prev_ds_nodash'] == execution_ds_nodash
diff --git a/tests/core/test_core_to_contrib.py b/tests/core/test_core_to_contrib.py
index e91e69d987525..4e6d64b7d81f7 100644
--- a/tests/core/test_core_to_contrib.py
+++ b/tests/core/test_core_to_contrib.py
@@ -19,9 +19,9 @@
import importlib
import sys
from inspect import isabstract
-from typing import Any
from unittest import TestCase, mock
+import pytest
from parameterized import parameterized
from airflow.models.baseoperator import BaseOperator
@@ -30,20 +30,20 @@
class TestMovingCoreToContrib(TestCase):
@staticmethod
- def assert_warning(msg: str, warning: Any):
+ def assert_warning(msg: str, warnings):
error = f"Text '{msg}' not in warnings"
- assert any(msg in str(w) for w in warning.warnings), error
+ assert any(msg in str(w) for w in warnings), error
def assert_is_subclass(self, clazz, other):
- self.assertTrue(issubclass(clazz, other), f"{clazz} is not subclass of {other}")
+ assert issubclass(clazz, other), f"{clazz} is not subclass of {other}"
def assert_proper_import(self, old_resource, new_resource):
new_path, _, _ = new_resource.rpartition(".")
old_path, _, _ = old_resource.rpartition(".")
- with self.assertWarns(DeprecationWarning) as warning_msg:
+ with pytest.warns(DeprecationWarning) as warnings:
# Reload to see deprecation warning each time
importlib.reload(importlib.import_module(old_path))
- self.assert_warning(new_path, warning_msg)
+ self.assert_warning(new_path, warnings)
def skip_test_with_mssql_in_py38(self, path_a="", path_b=""):
py_38 = sys.version_info >= (3, 8)
@@ -75,15 +75,15 @@ def test_is_class_deprecated(self, new_module, old_module):
self.skip_test_with_mssql_in_py38(new_module, old_module)
deprecation_warning_msg = "This class is deprecated."
old_module_class = self.get_class_from_path(old_module)
- with self.assertWarnsRegex(DeprecationWarning, deprecation_warning_msg) as wrn:
+ with pytest.warns(DeprecationWarning, match=deprecation_warning_msg) as warnings:
with mock.patch(f"{new_module}.__init__") as init_mock:
init_mock.return_value = None
klass = old_module_class()
if isinstance(klass, BaseOperator):
# In case of operators we are validating that proper stacklevel
# is used (=3 or =4 if @apply_defaults)
- assert len(wrn.warnings) == 1
- assert wrn.warnings[0].filename == __file__
+ assert len(warnings) == 1
+ assert warnings[0].filename == __file__
init_mock.assert_called_once_with()
@parameterized.expand(ALL)
diff --git a/tests/core/test_impersonation_tests.py b/tests/core/test_impersonation_tests.py
index 686142ebcb8e9..bc1fa9b04e5c3 100644
--- a/tests/core/test_impersonation_tests.py
+++ b/tests/core/test_impersonation_tests.py
@@ -140,7 +140,7 @@ def run_backfill(self, dag_id, task_id):
ti = models.TaskInstance(task=dag.get_task(task_id), execution_date=DEFAULT_DATE)
ti.refresh_from_db()
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
def test_impersonation(self):
"""
@@ -203,7 +203,7 @@ def run_backfill(self, dag_id, task_id):
ti = models.TaskInstance(task=dag.get_task(task_id), execution_date=DEFAULT_DATE)
ti.refresh_from_db()
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
@mock_custom_module_path(TEST_UTILS_FOLDER)
def test_impersonation_custom(self):
diff --git a/tests/core/test_logging_config.py b/tests/core/test_logging_config.py
index 9c6983319c1a0..635af8dba0c17 100644
--- a/tests/core/test_logging_config.py
+++ b/tests/core/test_logging_config.py
@@ -25,6 +25,8 @@
import unittest
from unittest.mock import patch
+import pytest
+
from airflow.configuration import conf
from tests.test_utils.config import conf_vars
@@ -192,7 +194,7 @@ def test_loading_invalid_local_settings(self):
with settings_context(SETTINGS_FILE_INVALID):
with patch.object(log, 'error') as mock_info:
# Load config
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
configure_logging()
mock_info.assert_called_once_with(
@@ -230,7 +232,7 @@ def test_loading_no_local_settings(self):
with settings_context(SETTINGS_FILE_EMPTY):
from airflow.logging_config import configure_logging
- with self.assertRaises(ImportError):
+ with pytest.raises(ImportError):
configure_logging()
# When the key is not available in the configuration
@@ -254,9 +256,9 @@ def test_1_9_config(self):
from airflow.logging_config import configure_logging
with conf_vars({('logging', 'task_log_reader'): 'file.task'}):
- with self.assertWarnsRegex(DeprecationWarning, r'file.task'):
+ with pytest.warns(DeprecationWarning, match=r'file.task'):
configure_logging()
- self.assertEqual(conf.get('logging', 'task_log_reader'), 'task')
+ assert conf.get('logging', 'task_log_reader') == 'task'
def test_loading_remote_logging_with_wasb_handler(self):
"""Test if logging can be configured successfully for Azure Blob Storage"""
@@ -275,4 +277,4 @@ def test_loading_remote_logging_with_wasb_handler(self):
configure_logging()
logger = logging.getLogger('airflow.task')
- self.assertIsInstance(logger.handlers[0], WasbTaskHandler)
+ assert isinstance(logger.handlers[0], WasbTaskHandler)
diff --git a/tests/core/test_providers_manager.py b/tests/core/test_providers_manager.py
index ec05da2ff9c87..7d80c58265d33 100644
--- a/tests/core/test_providers_manager.py
+++ b/tests/core/test_providers_manager.py
@@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import re
import unittest
from airflow.providers_manager import ProvidersManager
@@ -201,26 +202,26 @@ def test_providers_are_loaded(self):
for provider in provider_list:
package_name = provider_manager.providers[provider][1]['package-name']
version = provider_manager.providers[provider][0]
- self.assertRegex(version, r'[0-9]*\.[0-9]*\.[0-9]*.*')
- self.assertEqual(package_name, provider)
- self.assertEqual(ALL_PROVIDERS, provider_list)
+ assert re.search(r'[0-9]*\.[0-9]*\.[0-9]*.*', version)
+ assert package_name == provider
+ assert ALL_PROVIDERS == provider_list
def test_hooks(self):
provider_manager = ProvidersManager()
connections_list = list(provider_manager.hooks.keys())
- self.assertEqual(CONNECTIONS_LIST, connections_list)
+ assert CONNECTIONS_LIST == connections_list
def test_connection_form_widgets(self):
provider_manager = ProvidersManager()
connections_form_widgets = list(provider_manager.connection_form_widgets.keys())
- self.assertEqual(CONNECTION_FORM_WIDGETS, connections_form_widgets)
+ assert CONNECTION_FORM_WIDGETS == connections_form_widgets
def test_field_behaviours(self):
provider_manager = ProvidersManager()
connections_with_field_behaviours = list(provider_manager.field_behaviours.keys())
- self.assertEqual(CONNECTIONS_WITH_FIELD_BEHAVIOURS, connections_with_field_behaviours)
+ assert CONNECTIONS_WITH_FIELD_BEHAVIOURS == connections_with_field_behaviours
def test_extra_links(self):
provider_manager = ProvidersManager()
extra_link_class_names = list(provider_manager.extra_links_class_names)
- self.assertEqual(EXTRA_LINKS, extra_link_class_names)
+ assert EXTRA_LINKS == extra_link_class_names
diff --git a/tests/core/test_sentry.py b/tests/core/test_sentry.py
index bfff5843fdfc4..44a39d9126a68 100644
--- a/tests/core/test_sentry.py
+++ b/tests/core/test_sentry.py
@@ -91,7 +91,7 @@ def test_add_tagging(self):
self.sentry.add_tagging(task_instance=self.ti)
with configure_scope() as scope:
for key, value in scope._tags.items():
- self.assertEqual(TEST_SCOPE[key], value)
+ assert TEST_SCOPE[key] == value
@freeze_time(CRUMB_DATE.isoformat())
def test_add_breadcrumbs(self):
@@ -103,4 +103,4 @@ def test_add_breadcrumbs(self):
with configure_scope() as scope:
test_crumb = scope._breadcrumbs.pop()
- self.assertEqual(CRUMB, test_crumb)
+ assert CRUMB == test_crumb
diff --git a/tests/core/test_settings.py b/tests/core/test_settings.py
index 0915ec3b6a8be..2c7bb9f2a7e86 100644
--- a/tests/core/test_settings.py
+++ b/tests/core/test_settings.py
@@ -22,6 +22,8 @@
import unittest
from unittest.mock import MagicMock, call
+import pytest
+
from airflow.exceptions import AirflowClusterPolicyViolation
from tests.test_utils.config import conf_vars
@@ -112,7 +114,7 @@ def test_import_with_dunder_all_not_specified(self):
settings.import_local_settings()
- with self.assertRaises(AttributeError):
+ with pytest.raises(AttributeError):
settings.not_policy() # pylint: disable=no-member
def test_import_with_dunder_all(self):
@@ -179,7 +181,7 @@ def test_custom_policy(self):
task_instance = MagicMock()
task_instance.owner = 'airflow'
- with self.assertRaises(AirflowClusterPolicyViolation):
+ with pytest.raises(AirflowClusterPolicyViolation):
settings.task_must_have_owners(task_instance) # pylint: disable=no-member
@@ -190,10 +192,10 @@ class TestUpdatedConfigNames(unittest.TestCase):
def test_updates_deprecated_session_timeout_config_val_when_new_config_val_is_default(self):
from airflow import settings
- with self.assertWarns(DeprecationWarning):
+ with pytest.warns(DeprecationWarning):
session_lifetime_config = settings.get_session_lifetime_config()
minutes_in_five_days = 5 * 24 * 60
- self.assertEqual(session_lifetime_config, minutes_in_five_days)
+ assert session_lifetime_config == minutes_in_five_days
@conf_vars(
{("webserver", "session_lifetime_days"): '5', ("webserver", "session_lifetime_minutes"): '43201'}
@@ -202,7 +204,7 @@ def test_uses_updated_session_timeout_config_when_val_is_not_default(self):
from airflow import settings
session_lifetime_config = settings.get_session_lifetime_config()
- self.assertEqual(session_lifetime_config, 43201)
+ assert session_lifetime_config == 43201
@conf_vars({("webserver", "session_lifetime_days"): ''})
def test_uses_updated_session_timeout_config_by_default(self):
@@ -210,4 +212,4 @@ def test_uses_updated_session_timeout_config_by_default(self):
session_lifetime_config = settings.get_session_lifetime_config()
default_timeout_minutes = 30 * 24 * 60
- self.assertEqual(session_lifetime_config, default_timeout_minutes)
+ assert session_lifetime_config == default_timeout_minutes
diff --git a/tests/core/test_sqlalchemy_config.py b/tests/core/test_sqlalchemy_config.py
index a7cd08b727bd7..c4c909bd203f1 100644
--- a/tests/core/test_sqlalchemy_config.py
+++ b/tests/core/test_sqlalchemy_config.py
@@ -19,6 +19,7 @@
import unittest
from unittest.mock import patch
+import pytest
from sqlalchemy.pool import NullPool
from airflow import settings
@@ -92,6 +93,6 @@ def test_sql_alchemy_invalid_connect_args(
('core', 'sql_alchemy_connect_args'): 'does.not.exist',
('core', 'sql_alchemy_pool_enabled'): 'False',
}
- with self.assertRaises(AirflowConfigException):
+ with pytest.raises(AirflowConfigException):
with conf_vars(config):
settings.configure_orm()
diff --git a/tests/core/test_stats.py b/tests/core/test_stats.py
index eebdde2973097..428192b8b0d23 100644
--- a/tests/core/test_stats.py
+++ b/tests/core/test_stats.py
@@ -21,6 +21,7 @@
from unittest import mock
from unittest.mock import Mock
+import pytest
import statsd
import airflow
@@ -107,7 +108,7 @@ def test_does_not_send_stats_using_dogstatsd(self, mock_dogstatsd):
)
def test_load_custom_statsd_client(self):
importlib.reload(airflow.stats)
- self.assertEqual('CustomStatsd', type(airflow.stats.Stats.statsd).__name__)
+ assert 'CustomStatsd' == type(airflow.stats.Stats.statsd).__name__ # noqa: E721
@conf_vars(
{
@@ -127,9 +128,9 @@ def test_does_use_custom_statsd_client(self):
}
)
def test_load_invalid_custom_stats_client(self):
- with self.assertRaisesRegex(
+ with pytest.raises(
AirflowConfigException,
- re.escape(
+ match=re.escape(
'Your custom Statsd client must extend the statsd.'
'StatsClient in order to ensure backwards compatibility.'
),
diff --git a/tests/executors/test_base_executor.py b/tests/executors/test_base_executor.py
index 06fe03067c84f..22518783ed671 100644
--- a/tests/executors/test_base_executor.py
+++ b/tests/executors/test_base_executor.py
@@ -41,9 +41,9 @@ def test_get_event_buffer(self):
executor.event_buffer[key2] = state, None
executor.event_buffer[key3] = state, None
- self.assertEqual(len(executor.get_event_buffer(("my_dag1",))), 1)
- self.assertEqual(len(executor.get_event_buffer()), 2)
- self.assertEqual(len(executor.event_buffer), 0)
+ assert len(executor.get_event_buffer(("my_dag1",))) == 1
+ assert len(executor.get_event_buffer()) == 2
+ assert len(executor.event_buffer) == 0
@mock.patch('airflow.executors.base_executor.BaseExecutor.sync')
@mock.patch('airflow.executors.base_executor.BaseExecutor.trigger_tasks')
@@ -71,4 +71,4 @@ def test_try_adopt_task_instances(self):
key2 = TaskInstance(task=task_2, execution_date=date)
key3 = TaskInstance(task=task_3, execution_date=date)
tis = [key1, key2, key3]
- self.assertEqual(BaseExecutor().try_adopt_task_instances(tis), tis)
+ assert BaseExecutor().try_adopt_task_instances(tis) == tis
diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py
index 38aa583f486a0..944fa49c4b932 100644
--- a/tests/executors/test_celery_executor.py
+++ b/tests/executors/test_celery_executor.py
@@ -116,7 +116,7 @@ def fake_execute_command(command):
with _prepare_app(broker_url, execute=fake_execute_command) as app:
executor = celery_executor.CeleryExecutor()
- self.assertEqual(executor.tasks, {})
+ assert executor.tasks == {}
executor.start()
with start_worker(app=app, logfile=sys.stdout, loglevel='info'):
@@ -146,32 +146,25 @@ def fake_execute_command(command):
executor._process_tasks(task_tuples_to_send)
- self.assertEqual(
- list(executor.tasks.keys()),
- [
- ('success', 'fake_simple_ti', execute_date, 0),
- ('fail', 'fake_simple_ti', execute_date, 0),
- ],
- )
- self.assertEqual(
- executor.event_buffer[('success', 'fake_simple_ti', execute_date, 0)][0], State.QUEUED
- )
- self.assertEqual(
- executor.event_buffer[('fail', 'fake_simple_ti', execute_date, 0)][0], State.QUEUED
+ assert list(executor.tasks.keys()) == [
+ ('success', 'fake_simple_ti', execute_date, 0),
+ ('fail', 'fake_simple_ti', execute_date, 0),
+ ]
+ assert (
+ executor.event_buffer[('success', 'fake_simple_ti', execute_date, 0)][0] == State.QUEUED
)
+ assert executor.event_buffer[('fail', 'fake_simple_ti', execute_date, 0)][0] == State.QUEUED
executor.end(synchronous=True)
- self.assertEqual(
- executor.event_buffer[('success', 'fake_simple_ti', execute_date, 0)][0], State.SUCCESS
- )
- self.assertEqual(executor.event_buffer[('fail', 'fake_simple_ti', execute_date, 0)][0], State.FAILED)
+ assert executor.event_buffer[('success', 'fake_simple_ti', execute_date, 0)][0] == State.SUCCESS
+ assert executor.event_buffer[('fail', 'fake_simple_ti', execute_date, 0)][0] == State.FAILED
- self.assertNotIn('success', executor.tasks)
- self.assertNotIn('fail', executor.tasks)
+ assert 'success' not in executor.tasks
+ assert 'fail' not in executor.tasks
- self.assertEqual(executor.queued_tasks, {})
- self.assertEqual(timedelta(0, 600), executor.task_adoption_timeout)
+ assert executor.queued_tasks == {}
+ assert timedelta(0, 600) == executor.task_adoption_timeout
@pytest.mark.integration("redis")
@pytest.mark.integration("rabbitmq")
@@ -198,8 +191,8 @@ def fake_execute_command():
executor.queued_tasks[key] = value_tuple
executor.task_publish_retries[key] = 1
executor.heartbeat()
- self.assertEqual(0, len(executor.queued_tasks), "Task should no longer be queued")
- self.assertEqual(executor.event_buffer[('fail', 'fake_simple_ti', when, 0)][0], State.FAILED)
+ assert 0 == len(executor.queued_tasks), "Task should no longer be queued"
+ assert executor.event_buffer[('fail', 'fake_simple_ti', when, 0)][0] == State.FAILED
@pytest.mark.integration("redis")
@pytest.mark.integration("rabbitmq")
@@ -216,8 +209,8 @@ def fake_execute_command(command):
# fake_execute_command takes no arguments while execute_command takes 1,
# which will cause TypeError when calling task.apply_async()
executor = celery_executor.CeleryExecutor()
- self.assertEqual(executor.task_publish_retries, {})
- self.assertEqual(executor.task_publish_max_retries, 3, msg="Assert Default Max Retries is 3")
+ assert executor.task_publish_retries == {}
+ assert executor.task_publish_max_retries == 3, "Assert Default Max Retries is 3"
task = BashOperator(
task_id="test", bash_command="true", dag=DAG(dag_id='id'), start_date=datetime.now()
@@ -234,39 +227,36 @@ def fake_execute_command(command):
# Test that when heartbeat is called again, task is published again to Celery Queue
executor.heartbeat()
- self.assertEqual(dict(executor.task_publish_retries), {key: 2})
- self.assertEqual(1, len(executor.queued_tasks), "Task should remain in queue")
- self.assertEqual(executor.event_buffer, {})
- self.assertIn(
+ assert dict(executor.task_publish_retries) == {key: 2}
+ assert 1 == len(executor.queued_tasks), "Task should remain in queue"
+ assert executor.event_buffer == {}
+ assert (
"INFO:airflow.executors.celery_executor.CeleryExecutor:"
- f"[Try 1 of 3] Task Timeout Error for Task: ({key}).",
- cm.output,
+ f"[Try 1 of 3] Task Timeout Error for Task: ({key})." in cm.output
)
executor.heartbeat()
- self.assertEqual(dict(executor.task_publish_retries), {key: 3})
- self.assertEqual(1, len(executor.queued_tasks), "Task should remain in queue")
- self.assertEqual(executor.event_buffer, {})
- self.assertIn(
+ assert dict(executor.task_publish_retries) == {key: 3}
+ assert 1 == len(executor.queued_tasks), "Task should remain in queue"
+ assert executor.event_buffer == {}
+ assert (
"INFO:airflow.executors.celery_executor.CeleryExecutor:"
- f"[Try 2 of 3] Task Timeout Error for Task: ({key}).",
- cm.output,
+ f"[Try 2 of 3] Task Timeout Error for Task: ({key})." in cm.output
)
executor.heartbeat()
- self.assertEqual(dict(executor.task_publish_retries), {key: 4})
- self.assertEqual(1, len(executor.queued_tasks), "Task should remain in queue")
- self.assertEqual(executor.event_buffer, {})
- self.assertIn(
+ assert dict(executor.task_publish_retries) == {key: 4}
+ assert 1 == len(executor.queued_tasks), "Task should remain in queue"
+ assert executor.event_buffer == {}
+ assert (
"INFO:airflow.executors.celery_executor.CeleryExecutor:"
- f"[Try 3 of 3] Task Timeout Error for Task: ({key}).",
- cm.output,
+ f"[Try 3 of 3] Task Timeout Error for Task: ({key})." in cm.output
)
executor.heartbeat()
- self.assertEqual(dict(executor.task_publish_retries), {})
- self.assertEqual(0, len(executor.queued_tasks), "Task should no longer be in queue")
- self.assertEqual(executor.event_buffer[('fail', 'fake_simple_ti', when, 0)][0], State.FAILED)
+ assert dict(executor.task_publish_retries) == {}
+ assert 0 == len(executor.queued_tasks), "Task should no longer be in queue"
+ assert executor.event_buffer[('fail', 'fake_simple_ti', when, 0)][0] == State.FAILED
@pytest.mark.quarantined
@pytest.mark.backend("mysql", "postgres")
@@ -277,8 +267,8 @@ def test_exception_propagation(self):
executor.tasks = {'key': FakeCeleryResult()}
executor.bulk_state_fetcher._get_many_using_multiprocessing(executor.tasks.values())
- self.assertTrue(any(celery_executor.CELERY_FETCH_ERR_MSG_HEADER in line for line in cm.output))
- self.assertTrue(any("Exception" in line for line in cm.output))
+ assert any(celery_executor.CELERY_FETCH_ERR_MSG_HEADER in line for line in cm.output)
+ assert any("Exception" in line for line in cm.output)
@mock.patch('airflow.executors.celery_executor.CeleryExecutor.sync')
@mock.patch('airflow.executors.celery_executor.CeleryExecutor.trigger_tasks')
@@ -323,7 +313,7 @@ def test_try_adopt_task_instances_none(self):
tis = [key1]
executor = celery_executor.CeleryExecutor()
- self.assertEqual(executor.try_adopt_task_instances(tis), tis)
+ assert executor.try_adopt_task_instances(tis) == tis
@pytest.mark.backend("mysql", "postgres")
def test_try_adopt_task_instances(self):
@@ -346,24 +336,21 @@ def test_try_adopt_task_instances(self):
tis = [ti1, ti2]
executor = celery_executor.CeleryExecutor()
- self.assertEqual(executor.running, set())
- self.assertEqual(executor.adopted_task_timeouts, {})
- self.assertEqual(executor.tasks, {})
+ assert executor.running == set()
+ assert executor.adopted_task_timeouts == {}
+ assert executor.tasks == {}
not_adopted_tis = executor.try_adopt_task_instances(tis)
key_1 = TaskInstanceKey(dag.dag_id, task_1.task_id, exec_date, try_number)
key_2 = TaskInstanceKey(dag.dag_id, task_2.task_id, exec_date, try_number)
- self.assertEqual(executor.running, {key_1, key_2})
- self.assertEqual(
- dict(executor.adopted_task_timeouts),
- {
- key_1: queued_dttm + executor.task_adoption_timeout,
- key_2: queued_dttm + executor.task_adoption_timeout,
- },
- )
- self.assertEqual(executor.tasks, {key_1: AsyncResult("231"), key_2: AsyncResult("232")})
- self.assertEqual(not_adopted_tis, [])
+ assert executor.running == {key_1, key_2}
+ assert dict(executor.adopted_task_timeouts) == {
+ key_1: queued_dttm + executor.task_adoption_timeout,
+ key_2: queued_dttm + executor.task_adoption_timeout,
+ }
+ assert executor.tasks == {key_1: AsyncResult("231"), key_2: AsyncResult("232")}
+ assert not_adopted_tis == []
@pytest.mark.backend("mysql", "postgres")
def test_check_for_stalled_adopted_tasks(self):
@@ -387,9 +374,9 @@ def test_check_for_stalled_adopted_tasks(self):
}
executor.tasks = {key_1: AsyncResult("231"), key_2: AsyncResult("232")}
executor.sync()
- self.assertEqual(executor.event_buffer, {key_1: (State.FAILED, None), key_2: (State.FAILED, None)})
- self.assertEqual(executor.tasks, {})
- self.assertEqual(executor.adopted_task_timeouts, {})
+ assert executor.event_buffer == {key_1: (State.FAILED, None), key_2: (State.FAILED, None)}
+ assert executor.tasks == {}
+ assert executor.adopted_task_timeouts == {}
def test_operation_timeout_config():
@@ -438,10 +425,10 @@ def test_should_support_kv_backend(self, mock_mget):
# Assert called - ignore order
mget_args, _ = mock_mget.call_args
- self.assertEqual(set(mget_args[0]), {b'celery-task-meta-456', b'celery-task-meta-123'})
+ assert set(mget_args[0]) == {b'celery-task-meta-456', b'celery-task-meta-123'}
mock_mget.assert_called_once_with(mock.ANY)
- self.assertEqual(result, {'123': ('SUCCESS', None), '456': ("PENDING", None)})
+ assert result == {'123': ('SUCCESS', None), '456': ("PENDING", None)}
@mock.patch("celery.backends.database.DatabaseBackend.ResultSession")
@pytest.mark.integration("redis")
@@ -465,7 +452,7 @@ def test_should_support_db_backend(self, mock_session):
]
)
- self.assertEqual(result, {'123': ('SUCCESS', None), '456': ("PENDING", None)})
+ assert result == {'123': ('SUCCESS', None), '456': ("PENDING", None)}
@pytest.mark.integration("redis")
@pytest.mark.integration("rabbitmq")
@@ -483,4 +470,4 @@ def test_should_support_base_backend(self):
]
)
- self.assertEqual(result, {'123': ('SUCCESS', None), '456': ("PENDING", None)})
+ assert result == {'123': ('SUCCESS', None), '456': ("PENDING", None)}
diff --git a/tests/executors/test_dask_executor.py b/tests/executors/test_dask_executor.py
index 09a22f7bedbc3..d23f94b6aa2be 100644
--- a/tests/executors/test_dask_executor.py
+++ b/tests/executors/test_dask_executor.py
@@ -62,12 +62,12 @@ def assert_tasks_on_executor(self, executor):
)
# both tasks should have finished
- self.assertTrue(success_future.done())
- self.assertTrue(fail_future.done())
+ assert success_future.done()
+ assert fail_future.done()
# check task exceptions
- self.assertTrue(success_future.exception() is None)
- self.assertTrue(fail_future.exception() is not None)
+ assert success_future.exception() is None
+ assert fail_future.exception() is not None
class TestDaskExecutor(TestBaseDask):
diff --git a/tests/executors/test_executor_loader.py b/tests/executors/test_executor_loader.py
index 63ef8dd71961f..43dbbbd396d13 100644
--- a/tests/executors/test_executor_loader.py
+++ b/tests/executors/test_executor_loader.py
@@ -56,19 +56,19 @@ def tearDown(self) -> None:
def test_should_support_executor_from_core(self, executor_name):
with conf_vars({("core", "executor"): executor_name}):
executor = ExecutorLoader.get_default_executor()
- self.assertIsNotNone(executor)
- self.assertEqual(executor_name, executor.__class__.__name__)
+ assert executor is not None
+ assert executor_name == executor.__class__.__name__
@mock.patch("airflow.plugins_manager.plugins", [FakePlugin()])
@mock.patch("airflow.plugins_manager.executors_modules", None)
def test_should_support_plugins(self):
with conf_vars({("core", "executor"): f"{TEST_PLUGIN_NAME}.FakeExecutor"}):
executor = ExecutorLoader.get_default_executor()
- self.assertIsNotNone(executor)
- self.assertEqual("FakeExecutor", executor.__class__.__name__)
+ assert executor is not None
+ assert "FakeExecutor" == executor.__class__.__name__
def test_should_support_custom_path(self):
with conf_vars({("core", "executor"): "tests.executors.test_executor_loader.FakeExecutor"}):
executor = ExecutorLoader.get_default_executor()
- self.assertIsNotNone(executor)
- self.assertEqual("FakeExecutor", executor.__class__.__name__)
+ assert executor is not None
+ assert "FakeExecutor" == executor.__class__.__name__
diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py
index 9d8d72f4ef0b4..9abb32884310b 100644
--- a/tests/executors/test_kubernetes_executor.py
+++ b/tests/executors/test_kubernetes_executor.py
@@ -83,7 +83,7 @@ def _is_safe_label_value(value):
def test_create_pod_id(self):
for dag_id, task_id in self._cases():
pod_name = PodGenerator.make_unique_pod_id(create_pod_id(dag_id, task_id))
- self.assertTrue(self._is_valid_pod_id(pod_name))
+ assert self._is_valid_pod_id(pod_name)
@unittest.skipIf(AirflowKubernetesScheduler is None, 'kubernetes python package is not installed')
@mock.patch("airflow.kubernetes.pod_generator.PodGenerator")
@@ -91,32 +91,30 @@ def test_create_pod_id(self):
def test_get_base_pod_from_template(self, mock_kubeconfig, mock_generator):
pod_template_file_path = "/bar/biz"
get_base_pod_from_template(pod_template_file_path, None)
- self.assertEqual("deserialize_model_dict", mock_generator.mock_calls[0][0])
- self.assertEqual(pod_template_file_path, mock_generator.mock_calls[0][1][0])
+ assert "deserialize_model_dict" == mock_generator.mock_calls[0][0]
+ assert pod_template_file_path == mock_generator.mock_calls[0][1][0]
mock_kubeconfig.pod_template_file = "/foo/bar"
get_base_pod_from_template(None, mock_kubeconfig)
- self.assertEqual("deserialize_model_dict", mock_generator.mock_calls[1][0])
- self.assertEqual("/foo/bar", mock_generator.mock_calls[1][1][0])
+ assert "deserialize_model_dict" == mock_generator.mock_calls[1][0]
+ assert "/foo/bar" == mock_generator.mock_calls[1][1][0]
def test_make_safe_label_value(self):
for dag_id, task_id in self._cases():
safe_dag_id = pod_generator.make_safe_label_value(dag_id)
- self.assertTrue(self._is_safe_label_value(safe_dag_id))
+ assert self._is_safe_label_value(safe_dag_id)
safe_task_id = pod_generator.make_safe_label_value(task_id)
- self.assertTrue(self._is_safe_label_value(safe_task_id))
+ assert self._is_safe_label_value(safe_task_id)
dag_id = "my_dag_id"
- self.assertEqual(dag_id, pod_generator.make_safe_label_value(dag_id))
+ assert dag_id == pod_generator.make_safe_label_value(dag_id)
dag_id = "my_dag_id_" + "a" * 64
- self.assertEqual(
- "my_dag_id_" + "a" * 43 + "-0ce114c45", pod_generator.make_safe_label_value(dag_id)
- )
+ assert "my_dag_id_" + "a" * 43 + "-0ce114c45" == pod_generator.make_safe_label_value(dag_id)
def test_execution_date_serialize_deserialize(self):
datetime_obj = datetime.now()
serialized_datetime = pod_generator.datetime_to_label_safe_datestring(datetime_obj)
new_datetime_obj = pod_generator.label_safe_datestring_to_datetime(serialized_datetime)
- self.assertEqual(datetime_obj, new_datetime_obj)
+ assert datetime_obj == new_datetime_obj
class TestKubernetesExecutor(unittest.TestCase):
@@ -172,7 +170,7 @@ def test_run_next_exception(self, mock_get_kube_client, mock_kubernetes_job_watc
kubernetes_executor.sync()
assert mock_kube_client.create_namespaced_pod.called
- self.assertFalse(kubernetes_executor.task_queue.empty())
+ assert not kubernetes_executor.task_queue.empty()
# Disable the ApiException
mock_kube_client.create_namespaced_pod.side_effect = None
@@ -180,7 +178,7 @@ def test_run_next_exception(self, mock_get_kube_client, mock_kubernetes_job_watc
# Execute the task without errors should empty the queue
kubernetes_executor.sync()
assert mock_kube_client.create_namespaced_pod.called
- self.assertTrue(kubernetes_executor.task_queue.empty())
+ assert kubernetes_executor.task_queue.empty()
@mock.patch('airflow.executors.kubernetes_executor.KubeConfig')
@mock.patch('airflow.executors.kubernetes_executor.KubernetesExecutor.sync')
@@ -203,7 +201,7 @@ def test_change_state_running(self, mock_get_kube_client, mock_kubernetes_job_wa
executor.start()
key = ('dag_id', 'task_id', 'ex_time', 'try_number1')
executor._change_state(key, State.RUNNING, 'pod_id', 'default')
- self.assertTrue(executor.event_buffer[key][0] == State.RUNNING)
+ assert executor.event_buffer[key][0] == State.RUNNING
@mock.patch('airflow.executors.kubernetes_executor.KubernetesJobWatcher')
@mock.patch('airflow.executors.kubernetes_executor.get_kube_client')
@@ -214,7 +212,7 @@ def test_change_state_success(self, mock_delete_pod, mock_get_kube_client, mock_
test_time = timezone.utcnow()
key = ('dag_id', 'task_id', test_time, 'try_number2')
executor._change_state(key, State.SUCCESS, 'pod_id', 'default')
- self.assertTrue(executor.event_buffer[key][0] == State.SUCCESS)
+ assert executor.event_buffer[key][0] == State.SUCCESS
mock_delete_pod.assert_called_once_with('pod_id', 'default')
@mock.patch('airflow.executors.kubernetes_executor.KubernetesJobWatcher')
@@ -230,7 +228,7 @@ def test_change_state_failed_no_deletion(
test_time = timezone.utcnow()
key = ('dag_id', 'task_id', test_time, 'try_number3')
executor._change_state(key, State.FAILED, 'pod_id', 'default')
- self.assertTrue(executor.event_buffer[key][0] == State.FAILED)
+ assert executor.event_buffer[key][0] == State.FAILED
mock_delete_pod.assert_not_called()
# pylint: enable=unused-argument
@@ -249,7 +247,7 @@ def test_change_state_skip_pod_deletion(
executor.start()
key = ('dag_id', 'task_id', test_time, 'try_number2')
executor._change_state(key, State.SUCCESS, 'pod_id', 'default')
- self.assertTrue(executor.event_buffer[key][0] == State.SUCCESS)
+ assert executor.event_buffer[key][0] == State.SUCCESS
mock_delete_pod.assert_not_called()
@mock.patch('airflow.executors.kubernetes_executor.KubernetesJobWatcher')
@@ -264,7 +262,7 @@ def test_change_state_failed_pod_deletion(
executor.start()
key = ('dag_id', 'task_id', 'ex_time', 'try_number2')
executor._change_state(key, State.FAILED, 'pod_id', 'test-namespace')
- self.assertTrue(executor.event_buffer[key][0] == State.FAILED)
+ assert executor.event_buffer[key][0] == State.FAILED
mock_delete_pod.assert_called_once_with('pod_id', 'test-namespace')
@mock.patch('airflow.executors.kubernetes_executor.get_kube_client')
@@ -278,20 +276,17 @@ def test_adopt_launched_task(self, mock_kube_client):
)
)
executor.adopt_launched_task(mock_kube_client, pod=pod, pod_ids=pod_ids)
- self.assertEqual(
- mock_kube_client.patch_namespaced_pod.call_args[1],
- {
- 'body': {
- 'metadata': {
- 'labels': {'airflow-worker': 'modified', 'dag_id': 'dag', 'task_id': 'task'},
- 'name': 'foo',
- }
- },
- 'name': 'foo',
- 'namespace': None,
+ assert mock_kube_client.patch_namespaced_pod.call_args[1] == {
+ 'body': {
+ 'metadata': {
+ 'labels': {'airflow-worker': 'modified', 'dag_id': 'dag', 'task_id': 'task'},
+ 'name': 'foo',
+ }
},
- )
- self.assertDictEqual(pod_ids, {})
+ 'name': 'foo',
+ 'namespace': None,
+ }
+ assert pod_ids == {}
@mock.patch('airflow.executors.kubernetes_executor.get_kube_client')
def test_not_adopt_unassigned_task(self, mock_kube_client):
@@ -311,5 +306,5 @@ def test_not_adopt_unassigned_task(self, mock_kube_client):
)
)
executor.adopt_launched_task(mock_kube_client, pod=pod, pod_ids=pod_ids)
- self.assertFalse(mock_kube_client.patch_namespaced_pod.called)
- self.assertDictEqual(pod_ids, {"foobar": {}})
+ assert not mock_kube_client.patch_namespaced_pod.called
+ assert pod_ids == {"foobar": {}}
diff --git a/tests/executors/test_local_executor.py b/tests/executors/test_local_executor.py
index b12f783c1b5f7..92ddd026afa90 100644
--- a/tests/executors/test_local_executor.py
+++ b/tests/executors/test_local_executor.py
@@ -64,7 +64,7 @@ def _test_execute(self, parallelism, success_command, fail_command):
executor.start()
success_key = 'success {}'
- self.assertTrue(executor.result_queue.empty())
+ assert executor.result_queue.empty()
execution_date = datetime.datetime.now()
for i in range(self.TEST_SUCCESS_COMMANDS):
@@ -79,16 +79,16 @@ def _test_execute(self, parallelism, success_command, fail_command):
executor.end()
# By that time Queues are already shutdown so we cannot check if they are empty
- self.assertEqual(len(executor.running), 0)
+ assert len(executor.running) == 0
for i in range(self.TEST_SUCCESS_COMMANDS):
key_id = success_key.format(i)
key = key_id, 'fake_ti', execution_date, 0
- self.assertEqual(executor.event_buffer[key][0], State.SUCCESS)
- self.assertEqual(executor.event_buffer[fail_key][0], State.FAILED)
+ assert executor.event_buffer[key][0] == State.SUCCESS
+ assert executor.event_buffer[fail_key][0] == State.FAILED
expected = self.TEST_SUCCESS_COMMANDS + 1 if parallelism == 0 else parallelism
- self.assertEqual(executor.workers_used, expected)
+ assert executor.workers_used == expected
def test_execution_subprocess_unlimited_parallelism(self):
with mock.patch.object(
diff --git a/tests/hooks/test_dbapi.py b/tests/hooks/test_dbapi.py
index fb4ee1c179640..2cc916d507fd0 100644
--- a/tests/hooks/test_dbapi.py
+++ b/tests/hooks/test_dbapi.py
@@ -20,6 +20,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.hooks.dbapi import DbApiHook
from airflow.models import Connection
@@ -48,7 +50,7 @@ def test_get_records(self):
self.cur.fetchall.return_value = rows
- self.assertEqual(rows, self.db_hook.get_records(statement))
+ assert rows == self.db_hook.get_records(statement)
assert self.conn.close.call_count == 1
assert self.cur.close.call_count == 1
@@ -61,7 +63,7 @@ def test_get_records_parameters(self):
self.cur.fetchall.return_value = rows
- self.assertEqual(rows, self.db_hook.get_records(statement, parameters))
+ assert rows == self.db_hook.get_records(statement, parameters)
assert self.conn.close.call_count == 1
assert self.cur.close.call_count == 1
@@ -71,7 +73,7 @@ def test_get_records_exception(self):
statement = "SQL"
self.cur.fetchall.side_effect = RuntimeError('Great Problems')
- with self.assertRaises(RuntimeError):
+ with pytest.raises(RuntimeError):
self.db_hook.get_records(statement)
assert self.conn.close.call_count == 1
@@ -88,7 +90,7 @@ def test_insert_rows(self):
assert self.cur.close.call_count == 1
commit_count = 2 # The first and last commit
- self.assertEqual(commit_count, self.conn.commit.call_count)
+ assert commit_count == self.conn.commit.call_count
sql = f"INSERT INTO {table} VALUES (%s)"
for row in rows:
@@ -104,7 +106,7 @@ def test_insert_rows_replace(self):
assert self.cur.close.call_count == 1
commit_count = 2 # The first and last commit
- self.assertEqual(commit_count, self.conn.commit.call_count)
+ assert commit_count == self.conn.commit.call_count
sql = f"REPLACE INTO {table} VALUES (%s)"
for row in rows:
@@ -121,7 +123,7 @@ def test_insert_rows_target_fields(self):
assert self.cur.close.call_count == 1
commit_count = 2 # The first and last commit
- self.assertEqual(commit_count, self.conn.commit.call_count)
+ assert commit_count == self.conn.commit.call_count
sql = "INSERT INTO {} ({}) VALUES (%s)".format(table, target_fields[0])
for row in rows:
@@ -138,7 +140,7 @@ def test_insert_rows_commit_every(self):
assert self.cur.close.call_count == 1
commit_count = 2 + divmod(len(rows), commit_every)[0]
- self.assertEqual(commit_count, self.conn.commit.call_count)
+ assert commit_count == self.conn.commit.call_count
sql = f"INSERT INTO {table} VALUES (%s)"
for row in rows:
@@ -155,7 +157,7 @@ def test_get_uri_schema_not_none(self):
port=1,
)
)
- self.assertEqual("conn_type://login:password@host:1/schema", self.db_hook.get_uri())
+ assert "conn_type://login:password@host:1/schema" == self.db_hook.get_uri()
def test_get_uri_schema_none(self):
self.db_hook.get_connection = mock.MagicMock(
@@ -163,7 +165,7 @@ def test_get_uri_schema_none(self):
conn_type="conn_type", host="host", login="login", password="password", schema=None, port=1
)
)
- self.assertEqual("conn_type://login:password@host:1/", self.db_hook.get_uri())
+ assert "conn_type://login:password@host:1/" == self.db_hook.get_uri()
def test_get_uri_special_characters(self):
self.db_hook.get_connection = mock.MagicMock(
@@ -176,7 +178,7 @@ def test_get_uri_special_characters(self):
port=1,
)
)
- self.assertEqual("conn_type://logi%23%21+n:pass%2A%21+word@host:1/schema", self.db_hook.get_uri())
+ assert "conn_type://logi%23%21+n:pass%2A%21+word@host:1/schema" == self.db_hook.get_uri()
def test_run_log(self):
statement = 'SQL'
diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py
index 43cc992f16f58..aa221f52372c6 100644
--- a/tests/jobs/test_backfill_job.py
+++ b/tests/jobs/test_backfill_job.py
@@ -107,7 +107,7 @@ def test_unfinished_dag_runs_set_to_failed(self):
dag_run.refresh_from_db()
- self.assertEqual(State.FAILED, dag_run.state)
+ assert State.FAILED == dag_run.state
def test_dag_run_with_finished_tasks_set_to_success(self):
dag = self._get_dummy_dag('dummy_dag')
@@ -131,7 +131,7 @@ def test_dag_run_with_finished_tasks_set_to_success(self):
dag_run.refresh_from_db()
- self.assertEqual(State.SUCCESS, dag_run.state)
+ assert State.SUCCESS == dag_run.state
@pytest.mark.xfail(condition=True, reason="This test is flaky")
@pytest.mark.backend("postgres", "mysql")
@@ -146,7 +146,7 @@ def test_trigger_controller_dag(self):
# target_dag,
# dag_runs=DagRun.find(dag_id='example_trigger_target_dag')
# )
- self.assertFalse(task_instances_list)
+ assert not task_instances_list
job = BackfillJob(
dag=dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_first_depends_on_past=True
@@ -159,7 +159,7 @@ def test_trigger_controller_dag(self):
# dag_runs=DagRun.find(dag_id='example_trigger_target_dag')
# )
- self.assertTrue(task_instances_list)
+ assert task_instances_list
@pytest.mark.backend("postgres", "mysql")
def test_backfill_multi_dates(self):
@@ -192,21 +192,18 @@ def test_backfill_multi_dates(self):
("run_this_last", DEFAULT_DATE),
("run_this_last", end_date),
]
- self.assertListEqual(
- [
- ((dag.dag_id, task_id, when, 1), (State.SUCCESS, None))
- for (task_id, when) in expected_execution_order
- ],
- executor.sorted_tasks,
- )
+ assert [
+ ((dag.dag_id, task_id, when, 1), (State.SUCCESS, None))
+ for (task_id, when) in expected_execution_order
+ ] == executor.sorted_tasks
session = settings.Session()
drs = session.query(DagRun).filter(DagRun.dag_id == dag.dag_id).order_by(DagRun.execution_date).all()
- self.assertTrue(drs[0].execution_date == DEFAULT_DATE)
- self.assertTrue(drs[0].state == State.SUCCESS)
- self.assertTrue(drs[1].execution_date == DEFAULT_DATE + datetime.timedelta(days=1))
- self.assertTrue(drs[1].state == State.SUCCESS)
+ assert drs[0].execution_date == DEFAULT_DATE
+ assert drs[0].state == State.SUCCESS
+ assert drs[1].execution_date == DEFAULT_DATE + datetime.timedelta(days=1)
+ assert drs[1].state == State.SUCCESS
dag.clear()
session.close()
@@ -271,13 +268,10 @@ def test_backfill_examples(self, dag_id, expected_execution_order):
)
job.run()
- self.assertListEqual(
- [
- ((dag_id, task_id, DEFAULT_DATE, 1), (State.SUCCESS, None))
- for task_id in expected_execution_order
- ],
- executor.sorted_tasks,
- )
+ assert [
+ ((dag_id, task_id, DEFAULT_DATE, 1), (State.SUCCESS, None))
+ for task_id in expected_execution_order
+ ] == executor.sorted_tasks
def test_backfill_conf(self):
dag = self._get_dummy_dag('test_backfill_conf')
@@ -296,7 +290,7 @@ def test_backfill_conf(self):
dr = DagRun.find(dag_id='test_backfill_conf')
- self.assertEqual(conf_, dr[0].conf)
+ assert conf_ == dr[0].conf
@patch('airflow.jobs.backfill_job.BackfillJob.log')
def test_backfill_respect_task_concurrency_limit(self, mock_log):
@@ -317,19 +311,19 @@ def test_backfill_respect_task_concurrency_limit(self, mock_log):
job.run()
- self.assertGreater(len(executor.history), 0)
+ assert len(executor.history) > 0
task_concurrency_limit_reached_at_least_once = False
num_running_task_instances = 0
for running_task_instances in executor.history:
- self.assertLessEqual(len(running_task_instances), task_concurrency)
+ assert len(running_task_instances) <= task_concurrency
num_running_task_instances += len(running_task_instances)
if len(running_task_instances) == task_concurrency:
task_concurrency_limit_reached_at_least_once = True
- self.assertEqual(8, num_running_task_instances)
- self.assertTrue(task_concurrency_limit_reached_at_least_once)
+ assert 8 == num_running_task_instances
+ assert task_concurrency_limit_reached_at_least_once
times_dag_concurrency_limit_reached_in_debug = self._times_called_with(
mock_log.debug,
@@ -346,9 +340,9 @@ def test_backfill_respect_task_concurrency_limit(self, mock_log):
TaskConcurrencyLimitReached,
)
- self.assertEqual(0, times_pool_limit_reached_in_debug)
- self.assertEqual(0, times_dag_concurrency_limit_reached_in_debug)
- self.assertGreater(times_task_concurrency_limit_reached_in_debug, 0)
+ assert 0 == times_pool_limit_reached_in_debug
+ assert 0 == times_dag_concurrency_limit_reached_in_debug
+ assert times_task_concurrency_limit_reached_in_debug > 0
@patch('airflow.jobs.backfill_job.BackfillJob.log')
def test_backfill_respect_dag_concurrency_limit(self, mock_log):
@@ -367,20 +361,20 @@ def test_backfill_respect_dag_concurrency_limit(self, mock_log):
job.run()
- self.assertGreater(len(executor.history), 0)
+ assert len(executor.history) > 0
concurrency_limit_reached_at_least_once = False
num_running_task_instances = 0
for running_task_instances in executor.history:
- self.assertLessEqual(len(running_task_instances), dag.concurrency)
+ assert len(running_task_instances) <= dag.concurrency
num_running_task_instances += len(running_task_instances)
if len(running_task_instances) == dag.concurrency:
concurrency_limit_reached_at_least_once = True
- self.assertEqual(8, num_running_task_instances)
- self.assertTrue(concurrency_limit_reached_at_least_once)
+ assert 8 == num_running_task_instances
+ assert concurrency_limit_reached_at_least_once
times_dag_concurrency_limit_reached_in_debug = self._times_called_with(
mock_log.debug,
@@ -397,9 +391,9 @@ def test_backfill_respect_dag_concurrency_limit(self, mock_log):
TaskConcurrencyLimitReached,
)
- self.assertEqual(0, times_pool_limit_reached_in_debug)
- self.assertEqual(0, times_task_concurrency_limit_reached_in_debug)
- self.assertGreater(times_dag_concurrency_limit_reached_in_debug, 0)
+ assert 0 == times_pool_limit_reached_in_debug
+ assert 0 == times_task_concurrency_limit_reached_in_debug
+ assert times_dag_concurrency_limit_reached_in_debug > 0
@patch('airflow.jobs.backfill_job.BackfillJob.log')
def test_backfill_respect_default_pool_limit(self, mock_log):
@@ -419,7 +413,7 @@ def test_backfill_respect_default_pool_limit(self, mock_log):
job.run()
- self.assertGreater(len(executor.history), 0)
+ assert len(executor.history) > 0
default_pool_task_slot_count_reached_at_least_once = False
@@ -429,16 +423,13 @@ def test_backfill_respect_default_pool_limit(self, mock_log):
# parallel per backfill should be less than
# default_pool slots at any point of time.
for running_task_instances in executor.history:
- self.assertLessEqual(
- len(running_task_instances),
- default_pool_slots,
- )
+ assert len(running_task_instances) <= default_pool_slots
num_running_task_instances += len(running_task_instances)
if len(running_task_instances) == default_pool_slots:
default_pool_task_slot_count_reached_at_least_once = True
- self.assertEqual(8, num_running_task_instances)
- self.assertTrue(default_pool_task_slot_count_reached_at_least_once)
+ assert 8 == num_running_task_instances
+ assert default_pool_task_slot_count_reached_at_least_once
times_dag_concurrency_limit_reached_in_debug = self._times_called_with(
mock_log.debug,
@@ -455,9 +446,9 @@ def test_backfill_respect_default_pool_limit(self, mock_log):
TaskConcurrencyLimitReached,
)
- self.assertEqual(0, times_dag_concurrency_limit_reached_in_debug)
- self.assertEqual(0, times_task_concurrency_limit_reached_in_debug)
- self.assertGreater(times_pool_limit_reached_in_debug, 0)
+ assert 0 == times_dag_concurrency_limit_reached_in_debug
+ assert 0 == times_task_concurrency_limit_reached_in_debug
+ assert times_pool_limit_reached_in_debug > 0
def test_backfill_pool_not_found(self):
dag = self._get_dummy_dag(
@@ -509,19 +500,19 @@ def test_backfill_respect_pool_limit(self, mock_log):
job.run()
- self.assertGreater(len(executor.history), 0)
+ assert len(executor.history) > 0
pool_was_full_at_least_once = False
num_running_task_instances = 0
for running_task_instances in executor.history:
- self.assertLessEqual(len(running_task_instances), slots)
+ assert len(running_task_instances) <= slots
num_running_task_instances += len(running_task_instances)
if len(running_task_instances) == slots:
pool_was_full_at_least_once = True
- self.assertEqual(8, num_running_task_instances)
- self.assertTrue(pool_was_full_at_least_once)
+ assert 8 == num_running_task_instances
+ assert pool_was_full_at_least_once
times_dag_concurrency_limit_reached_in_debug = self._times_called_with(
mock_log.debug,
@@ -538,9 +529,9 @@ def test_backfill_respect_pool_limit(self, mock_log):
TaskConcurrencyLimitReached,
)
- self.assertEqual(0, times_task_concurrency_limit_reached_in_debug)
- self.assertEqual(0, times_dag_concurrency_limit_reached_in_debug)
- self.assertGreater(times_pool_limit_reached_in_debug, 0)
+ assert 0 == times_task_concurrency_limit_reached_in_debug
+ assert 0 == times_dag_concurrency_limit_reached_in_debug
+ assert times_pool_limit_reached_in_debug > 0
def test_backfill_run_rescheduled(self):
dag = DAG(dag_id='test_backfill_run_rescheduled', start_date=DEFAULT_DATE, schedule_interval='@daily')
@@ -577,7 +568,7 @@ def test_backfill_run_rescheduled(self):
job.run()
ti = TI(task=dag.get_task('test_backfill_run_rescheduled_task-1'), execution_date=DEFAULT_DATE)
ti.refresh_from_db()
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
def test_backfill_rerun_failed_tasks(self):
dag = DAG(dag_id='test_backfill_rerun_failed', start_date=DEFAULT_DATE, schedule_interval='@daily')
@@ -611,7 +602,7 @@ def test_backfill_rerun_failed_tasks(self):
job.run()
ti = TI(task=dag.get_task('test_backfill_rerun_failed_task-1'), execution_date=DEFAULT_DATE)
ti.refresh_from_db()
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
def test_backfill_rerun_upstream_failed_tasks(self):
dag = DAG(
@@ -648,7 +639,7 @@ def test_backfill_rerun_upstream_failed_tasks(self):
job.run()
ti = TI(task=dag.get_task('test_backfill_rerun_upstream_failed_task-1'), execution_date=DEFAULT_DATE)
ti.refresh_from_db()
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
def test_backfill_rerun_failed_tasks_without_flag(self):
dag = DAG(dag_id='test_backfill_rerun_failed', start_date=DEFAULT_DATE, schedule_interval='@daily')
@@ -680,7 +671,7 @@ def test_backfill_rerun_failed_tasks_without_flag(self):
rerun_failed_tasks=False,
)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
job.run()
def test_backfill_ordered_concurrent_execute(self):
@@ -720,23 +711,19 @@ def test_backfill_ordered_concurrent_execute(self):
# test executor history keeps a list
history = executor.history
- self.assertListEqual(
- # key[0] is dag id, key[3] is try_number, we don't care about either of those here
- [sorted([item[-1].key[1:3] for item in batch]) for batch in history],
+ assert [sorted([item[-1].key[1:3] for item in batch]) for batch in history] == [
[
- [
- ('leave1', date0),
- ('leave1', date1),
- ('leave1', date2),
- ('leave2', date0),
- ('leave2', date1),
- ('leave2', date2),
- ],
- [('upstream_level_1', date0), ('upstream_level_1', date1), ('upstream_level_1', date2)],
- [('upstream_level_2', date0), ('upstream_level_2', date1), ('upstream_level_2', date2)],
- [('upstream_level_3', date0), ('upstream_level_3', date1), ('upstream_level_3', date2)],
+ ('leave1', date0),
+ ('leave1', date1),
+ ('leave1', date2),
+ ('leave2', date0),
+ ('leave2', date1),
+ ('leave2', date2),
],
- )
+ [('upstream_level_1', date0), ('upstream_level_1', date1), ('upstream_level_1', date2)],
+ [('upstream_level_2', date0), ('upstream_level_2', date1), ('upstream_level_2', date2)],
+ [('upstream_level_3', date0), ('upstream_level_3', date1), ('upstream_level_3', date2)],
+ ]
def test_backfill_pooled_tasks(self):
"""
@@ -763,7 +750,7 @@ def test_backfill_pooled_tasks(self):
pass
ti = TI(task=dag.get_task('test_backfill_pooled_task'), execution_date=DEFAULT_DATE)
ti.refresh_from_db()
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
def test_backfill_depends_on_past(self):
"""
@@ -774,11 +761,8 @@ def test_backfill_depends_on_past(self):
run_date = DEFAULT_DATE + datetime.timedelta(days=5)
# backfill should deadlock
- self.assertRaisesRegex(
- AirflowException,
- 'BackfillJob is deadlocked',
- BackfillJob(dag=dag, start_date=run_date, end_date=run_date).run,
- )
+ with pytest.raises(AirflowException, match='BackfillJob is deadlocked'):
+ BackfillJob(dag=dag, start_date=run_date, end_date=run_date).run()
BackfillJob(
dag=dag,
@@ -791,7 +775,7 @@ def test_backfill_depends_on_past(self):
# ti should have succeeded
ti = TI(dag.tasks[0], run_date)
ti.refresh_from_db()
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
def test_backfill_depends_on_past_backwards(self):
"""
@@ -814,13 +798,13 @@ def test_backfill_depends_on_past_backwards(self):
ti = TI(dag.get_task('test_dop_task'), end_date)
ti.refresh_from_db()
# runs fine forwards
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
# raises backwards
expected_msg = 'You cannot backfill backwards because one or more tasks depend_on_past: {}'.format(
'test_dop_task'
)
- with self.assertRaisesRegex(AirflowException, expected_msg):
+ with pytest.raises(AirflowException, match=expected_msg):
executor = MockExecutor()
job = BackfillJob(dag=dag, executor=executor, run_backwards=True, **kwargs)
job.run()
@@ -841,7 +825,7 @@ def test_cli_receives_delay_arg(self):
'0.5',
]
parsed_args = self.parser.parse_args(args)
- self.assertEqual(0.5, parsed_args.delay_on_limit)
+ assert 0.5 == parsed_args.delay_on_limit
def _get_dag_test_max_active_limits(self, dag_id, max_active_runs=1):
dag = DAG(
@@ -878,8 +862,8 @@ def test_backfill_max_limit_check_within_limit(self):
job.run()
dagruns = DagRun.find(dag_id=dag.dag_id)
- self.assertEqual(2, len(dagruns))
- self.assertTrue(all(run.state == State.SUCCESS for run in dagruns))
+ assert 2 == len(dagruns)
+ assert all(run.state == State.SUCCESS for run in dagruns)
def test_backfill_max_limit_check(self):
dag_id = 'test_backfill_max_limit_check'
@@ -929,8 +913,8 @@ def run_backfill(cond):
dag_run_created_cond.wait(timeout=1.5)
dagruns = DagRun.find(dag_id=dag_id)
dr = dagruns[0]
- self.assertEqual(1, len(dagruns))
- self.assertEqual(dr.run_id, run_id)
+ assert 1 == len(dagruns)
+ assert dr.run_id == run_id
# allow the backfill to execute
# by setting the existing dag run to SUCCESS,
@@ -942,8 +926,8 @@ def run_backfill(cond):
backfill_job_thread.join()
dagruns = DagRun.find(dag_id=dag_id)
- self.assertEqual(3, len(dagruns)) # 2 from backfill + 1 existing
- self.assertEqual(dagruns[-1].run_id, dr.run_id)
+ assert 3 == len(dagruns) # 2 from backfill + 1 existing
+ assert dagruns[-1].run_id == dr.run_id
finally:
dag_run_created_cond.release()
@@ -971,8 +955,8 @@ def test_backfill_max_limit_check_no_count_existing(self):
dagruns = DagRun.find(dag_id=dag.dag_id)
# will only be able to run 1 (the existing one) since there's just
# one dag run slot left given the max_active_runs limit
- self.assertEqual(1, len(dagruns))
- self.assertEqual(State.SUCCESS, dagruns[0].state)
+ assert 1 == len(dagruns)
+ assert State.SUCCESS == dagruns[0].state
def test_backfill_max_limit_check_complete_loop(self):
dag = self._get_dag_test_max_active_limits('test_backfill_max_limit_check_complete_loop')
@@ -990,8 +974,8 @@ def test_backfill_max_limit_check_complete_loop(self):
success_dagruns = len(DagRun.find(dag_id=dag.dag_id, state=State.SUCCESS))
running_dagruns = len(DagRun.find(dag_id=dag.dag_id, state=State.RUNNING))
- self.assertEqual(success_expected, success_dagruns)
- self.assertEqual(0, running_dagruns) # no dag_runs in running state are left
+ assert success_expected == success_dagruns
+ assert 0 == running_dagruns # no dag_runs in running state are left
def test_sub_set_subdag(self):
dag = DAG('test_sub_set_subdag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
@@ -1018,17 +1002,18 @@ def test_sub_set_subdag(self):
job = BackfillJob(dag=sub_dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, executor=executor)
job.run()
- self.assertRaises(sqlalchemy.orm.exc.NoResultFound, dr.refresh_from_db)
+ with pytest.raises(sqlalchemy.orm.exc.NoResultFound):
+ dr.refresh_from_db()
# the run_id should have changed, so a refresh won't work
drs = DagRun.find(dag_id=dag.dag_id, execution_date=DEFAULT_DATE)
dr = drs[0]
- self.assertEqual(DagRun.generate_run_id(DagRunType.BACKFILL_JOB, DEFAULT_DATE), dr.run_id)
+ assert DagRun.generate_run_id(DagRunType.BACKFILL_JOB, DEFAULT_DATE) == dr.run_id
for ti in dr.get_task_instances():
if ti.task_id == 'leave1' or ti.task_id == 'leave2':
- self.assertEqual(State.SUCCESS, ti.state)
+ assert State.SUCCESS == ti.state
else:
- self.assertEqual(State.NONE, ti.state)
+ assert State.NONE == ti.state
def test_backfill_fill_blanks(self):
dag = DAG(
@@ -1072,25 +1057,27 @@ def test_backfill_fill_blanks(self):
session.close()
job = BackfillJob(dag=dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, executor=executor)
- self.assertRaisesRegex(AirflowException, 'Some task instances failed', job.run)
+ with pytest.raises(AirflowException, match='Some task instances failed'):
+ job.run()
- self.assertRaises(sqlalchemy.orm.exc.NoResultFound, dr.refresh_from_db)
+ with pytest.raises(sqlalchemy.orm.exc.NoResultFound):
+ dr.refresh_from_db()
# the run_id should have changed, so a refresh won't work
drs = DagRun.find(dag_id=dag.dag_id, execution_date=DEFAULT_DATE)
dr = drs[0]
- self.assertEqual(dr.state, State.FAILED)
+ assert dr.state == State.FAILED
tis = dr.get_task_instances()
for ti in tis:
if ti.task_id in (op1.task_id, op4.task_id, op6.task_id):
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
elif ti.task_id == op2.task_id:
- self.assertEqual(ti.state, State.FAILED)
+ assert ti.state == State.FAILED
elif ti.task_id == op3.task_id:
- self.assertEqual(ti.state, State.SKIPPED)
+ assert ti.state == State.SKIPPED
elif ti.task_id == op5.task_id:
- self.assertEqual(ti.state, State.UPSTREAM_FAILED)
+ assert ti.state == State.UPSTREAM_FAILED
def test_backfill_execute_subdag(self):
dag = self.dagbag.get_dag('example_subdag_operator')
@@ -1114,10 +1101,10 @@ def test_backfill_execute_subdag(self):
subdag_history = history[0]
# check that all 5 task instances of the subdag 'section-1' were executed
- self.assertEqual(5, len(subdag_history))
+ assert 5 == len(subdag_history)
for sdh in subdag_history:
ti = sdh[3]
- self.assertIn('section-1-task-', ti.task_id)
+ assert 'section-1-task-' in ti.task_id
with create_session() as session:
successful_subdag_runs = (
@@ -1129,7 +1116,7 @@ def test_backfill_execute_subdag(self):
.count()
)
- self.assertEqual(1, successful_subdag_runs)
+ assert 1 == successful_subdag_runs
subdag.clear()
dag.clear()
@@ -1150,15 +1137,15 @@ def test_subdag_clear_parentdag_downstream_clear(self):
ti_subdag = TI(task=dag.get_task('daily_job'), execution_date=DEFAULT_DATE)
ti_subdag.refresh_from_db()
- self.assertEqual(ti_subdag.state, State.SUCCESS)
+ assert ti_subdag.state == State.SUCCESS
ti_irrelevant = TI(task=dag.get_task('daily_job_irrelevant'), execution_date=DEFAULT_DATE)
ti_irrelevant.refresh_from_db()
- self.assertEqual(ti_irrelevant.state, State.SUCCESS)
+ assert ti_irrelevant.state == State.SUCCESS
ti_downstream = TI(task=dag.get_task('daily_job_downstream'), execution_date=DEFAULT_DATE)
ti_downstream.refresh_from_db()
- self.assertEqual(ti_downstream.state, State.SUCCESS)
+ assert ti_downstream.state == State.SUCCESS
sdag = subdag.sub_dag(
task_ids_or_regex='daily_job_subdag_task', include_downstream=True, include_upstream=False
@@ -1167,13 +1154,13 @@ def test_subdag_clear_parentdag_downstream_clear(self):
sdag.clear(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, include_parentdag=True)
ti_subdag.refresh_from_db()
- self.assertEqual(State.NONE, ti_subdag.state)
+ assert State.NONE == ti_subdag.state
ti_irrelevant.refresh_from_db()
- self.assertEqual(State.SUCCESS, ti_irrelevant.state)
+ assert State.SUCCESS == ti_irrelevant.state
ti_downstream.refresh_from_db()
- self.assertEqual(State.NONE, ti_downstream.state)
+ assert State.NONE == ti_downstream.state
subdag.clear()
dag.clear()
@@ -1213,11 +1200,11 @@ def test_backfill_execute_subdag_with_removed_task(self):
.first()
)
- self.assertIsNotNone(instance)
- self.assertEqual(instance.state, State.SUCCESS)
+ assert instance is not None
+ assert instance.state == State.SUCCESS
removed_task_ti.refresh_from_db()
- self.assertEqual(removed_task_ti.state, State.REMOVED)
+ assert removed_task_ti.state == State.REMOVED
subdag.clear()
dag.clear()
@@ -1246,11 +1233,11 @@ def test_update_counters(self):
ti.set_state(State.SUCCESS, session)
ti_status.running[ti.key] = ti
job._update_counters(ti_status=ti_status)
- self.assertTrue(len(ti_status.running) == 0)
- self.assertTrue(len(ti_status.succeeded) == 1)
- self.assertTrue(len(ti_status.skipped) == 0)
- self.assertTrue(len(ti_status.failed) == 0)
- self.assertTrue(len(ti_status.to_run) == 0)
+ assert len(ti_status.running) == 0
+ assert len(ti_status.succeeded) == 1
+ assert len(ti_status.skipped) == 0
+ assert len(ti_status.failed) == 0
+ assert len(ti_status.to_run) == 0
ti_status.succeeded.clear()
@@ -1258,11 +1245,11 @@ def test_update_counters(self):
ti.set_state(State.SKIPPED, session)
ti_status.running[ti.key] = ti
job._update_counters(ti_status=ti_status)
- self.assertTrue(len(ti_status.running) == 0)
- self.assertTrue(len(ti_status.succeeded) == 0)
- self.assertTrue(len(ti_status.skipped) == 1)
- self.assertTrue(len(ti_status.failed) == 0)
- self.assertTrue(len(ti_status.to_run) == 0)
+ assert len(ti_status.running) == 0
+ assert len(ti_status.succeeded) == 0
+ assert len(ti_status.skipped) == 1
+ assert len(ti_status.failed) == 0
+ assert len(ti_status.to_run) == 0
ti_status.skipped.clear()
@@ -1270,11 +1257,11 @@ def test_update_counters(self):
ti.set_state(State.FAILED, session)
ti_status.running[ti.key] = ti
job._update_counters(ti_status=ti_status)
- self.assertTrue(len(ti_status.running) == 0)
- self.assertTrue(len(ti_status.succeeded) == 0)
- self.assertTrue(len(ti_status.skipped) == 0)
- self.assertTrue(len(ti_status.failed) == 1)
- self.assertTrue(len(ti_status.to_run) == 0)
+ assert len(ti_status.running) == 0
+ assert len(ti_status.succeeded) == 0
+ assert len(ti_status.skipped) == 0
+ assert len(ti_status.failed) == 1
+ assert len(ti_status.to_run) == 0
ti_status.failed.clear()
@@ -1282,11 +1269,11 @@ def test_update_counters(self):
ti.set_state(State.UP_FOR_RETRY, session)
ti_status.running[ti.key] = ti
job._update_counters(ti_status=ti_status)
- self.assertTrue(len(ti_status.running) == 0)
- self.assertTrue(len(ti_status.succeeded) == 0)
- self.assertTrue(len(ti_status.skipped) == 0)
- self.assertTrue(len(ti_status.failed) == 0)
- self.assertTrue(len(ti_status.to_run) == 1)
+ assert len(ti_status.running) == 0
+ assert len(ti_status.succeeded) == 0
+ assert len(ti_status.skipped) == 0
+ assert len(ti_status.failed) == 0
+ assert len(ti_status.to_run) == 1
ti_status.to_run.clear()
@@ -1294,11 +1281,11 @@ def test_update_counters(self):
ti.set_state(State.UP_FOR_RESCHEDULE, session)
ti_status.running[ti.key] = ti
job._update_counters(ti_status=ti_status)
- self.assertTrue(len(ti_status.running) == 0)
- self.assertTrue(len(ti_status.succeeded) == 0)
- self.assertTrue(len(ti_status.skipped) == 0)
- self.assertTrue(len(ti_status.failed) == 0)
- self.assertTrue(len(ti_status.to_run) == 1)
+ assert len(ti_status.running) == 0
+ assert len(ti_status.succeeded) == 0
+ assert len(ti_status.skipped) == 0
+ assert len(ti_status.failed) == 0
+ assert len(ti_status.to_run) == 1
ti_status.to_run.clear()
@@ -1306,11 +1293,11 @@ def test_update_counters(self):
ti.set_state(State.NONE, session)
ti_status.running[ti.key] = ti
job._update_counters(ti_status=ti_status)
- self.assertTrue(len(ti_status.running) == 0)
- self.assertTrue(len(ti_status.succeeded) == 0)
- self.assertTrue(len(ti_status.skipped) == 0)
- self.assertTrue(len(ti_status.failed) == 0)
- self.assertTrue(len(ti_status.to_run) == 1)
+ assert len(ti_status.running) == 0
+ assert len(ti_status.succeeded) == 0
+ assert len(ti_status.skipped) == 0
+ assert len(ti_status.failed) == 0
+ assert len(ti_status.to_run) == 1
ti_status.to_run.clear()
@@ -1327,22 +1314,17 @@ def get_test_dag_for_backfill(schedule_interval=None):
return dag
test_dag = get_test_dag_for_backfill()
- self.assertEqual(
- [DEFAULT_DATE], test_dag.get_run_dates(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- )
+ assert [DEFAULT_DATE] == test_dag.get_run_dates(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
test_dag = get_test_dag_for_backfill(schedule_interval="@hourly")
- self.assertEqual(
- [
- DEFAULT_DATE - datetime.timedelta(hours=3),
- DEFAULT_DATE - datetime.timedelta(hours=2),
- DEFAULT_DATE - datetime.timedelta(hours=1),
- DEFAULT_DATE,
- ],
- test_dag.get_run_dates(
- start_date=DEFAULT_DATE - datetime.timedelta(hours=3),
- end_date=DEFAULT_DATE,
- ),
+ assert [
+ DEFAULT_DATE - datetime.timedelta(hours=3),
+ DEFAULT_DATE - datetime.timedelta(hours=2),
+ DEFAULT_DATE - datetime.timedelta(hours=1),
+ DEFAULT_DATE,
+ ] == test_dag.get_run_dates(
+ start_date=DEFAULT_DATE - datetime.timedelta(hours=3),
+ end_date=DEFAULT_DATE,
)
def test_backfill_run_backwards(self):
@@ -1369,8 +1351,8 @@ def test_backfill_run_backwards(self):
)
queued_times = [ti.queued_dttm for ti in tis]
- self.assertTrue(queued_times == sorted(queued_times, reverse=True))
- self.assertTrue(all(ti.state == State.SUCCESS for ti in tis))
+ assert queued_times == sorted(queued_times, reverse=True)
+ assert all(ti.state == State.SUCCESS for ti in tis)
dag.clear()
session.close()
@@ -1411,7 +1393,7 @@ def test_reset_orphaned_tasks_with_orphans(self):
session.merge(ti2)
session.commit()
- self.assertEqual(2, job.reset_state_for_orphaned_tasks())
+ assert 2 == job.reset_state_for_orphaned_tasks()
for ti in dr1_tis + dr2_tis:
ti.refresh_from_db()
@@ -1419,13 +1401,13 @@ def test_reset_orphaned_tasks_with_orphans(self):
# running dagrun should be reset
for state, ti in zip(states, dr1_tis):
if state in states_to_reset:
- self.assertIsNone(ti.state)
+ assert ti.state is None
else:
- self.assertEqual(state, ti.state)
+ assert state == ti.state
# otherwise not
for state, ti in zip(states, dr2_tis):
- self.assertEqual(state, ti.state)
+ assert state == ti.state
for state, ti in zip(states, dr1_tis):
ti.state = state
@@ -1435,7 +1417,7 @@ def test_reset_orphaned_tasks_with_orphans(self):
# check same for dag_run version
for state, ti in zip(states, dr2_tis):
- self.assertEqual(state, ti.state)
+ assert state == ti.state
def test_reset_orphaned_tasks_specified_dagrun(self):
"""Try to reset when we specify a dagrun and ensure nothing else is."""
@@ -1461,11 +1443,11 @@ def test_reset_orphaned_tasks_specified_dagrun(self):
session.commit()
num_reset_tis = job.reset_state_for_orphaned_tasks(filter_by_dag_run=dr2, session=session)
- self.assertEqual(1, num_reset_tis)
+ assert 1 == num_reset_tis
ti1.refresh_from_db(session=session)
ti2.refresh_from_db(session=session)
- self.assertEqual(State.SCHEDULED, ti1.state)
- self.assertEqual(State.NONE, ti2.state)
+ assert State.SCHEDULED == ti1.state
+ assert State.NONE == ti2.state
def test_job_id_is_assigned_to_dag_run(self):
dag_id = 'test_job_id_is_assigned_to_dag_run'
diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py
index dbc8aa55641be..fdd01633dcbb0 100644
--- a/tests/jobs/test_local_task_job.py
+++ b/tests/jobs/test_local_task_job.py
@@ -84,10 +84,10 @@ def test_localtaskjob_essential_attr(self):
essential_attr = ["dag_id", "job_type", "start_date", "hostname"]
check_result_1 = [hasattr(job1, attr) for attr in essential_attr]
- self.assertTrue(all(check_result_1))
+ assert all(check_result_1)
check_result_2 = [getattr(job1, attr) is not None for attr in essential_attr]
- self.assertTrue(all(check_result_2))
+ assert all(check_result_2)
@patch('os.getpid')
def test_localtaskjob_heartbeat(self, mock_pid):
@@ -111,7 +111,8 @@ def test_localtaskjob_heartbeat(self, mock_pid):
session.commit()
job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
- self.assertRaises(AirflowException, job1.heartbeat_callback)
+ with pytest.raises(AirflowException):
+ job1.heartbeat_callback() # pylint: disable=no-value-for-parameter
mock_pid.return_value = 1
ti.state = State.RUNNING
@@ -123,7 +124,8 @@ def test_localtaskjob_heartbeat(self, mock_pid):
job1.heartbeat_callback(session=None)
mock_pid.return_value = 2
- self.assertRaises(AirflowException, job1.heartbeat_callback)
+ with pytest.raises(AirflowException):
+ job1.heartbeat_callback() # pylint: disable=no-value-for-parameter
def test_heartbeat_failed_fast(self):
"""
@@ -160,13 +162,13 @@ def test_heartbeat_failed_fast(self):
heartbeat_records = []
job.heartbeat_callback = lambda session: heartbeat_records.append(job.latest_heartbeat)
job._execute()
- self.assertGreater(len(heartbeat_records), 2)
+ assert len(heartbeat_records) > 2
for i in range(1, len(heartbeat_records)):
time1 = heartbeat_records[i - 1]
time2 = heartbeat_records[i]
# Assert that difference small enough
delta = (time2 - time1).total_seconds()
- self.assertAlmostEqual(delta, job.heartrate, delta=0.05)
+ assert abs(delta - job.heartrate) < 0.05
@pytest.mark.quarantined
def test_mark_success_no_kill(self):
@@ -202,15 +204,15 @@ def test_mark_success_no_kill(self):
break
time.sleep(0.1)
ti.refresh_from_db()
- self.assertEqual(State.RUNNING, ti.state)
+ assert State.RUNNING == ti.state
ti.state = State.SUCCESS
session.merge(ti)
session.commit()
process.join(timeout=10)
- self.assertFalse(process.is_alive())
+ assert not process.is_alive()
ti.refresh_from_db()
- self.assertEqual(State.SUCCESS, ti.state)
+ assert State.SUCCESS == ti.state
def test_localtaskjob_double_trigger(self):
dagbag = DagBag(
@@ -247,8 +249,8 @@ def test_localtaskjob_double_trigger(self):
mock_method.assert_not_called()
ti = dr.get_task_instance(task_id=task.task_id, session=session)
- self.assertEqual(ti.pid, 1)
- self.assertEqual(ti.state, State.RUNNING)
+ assert ti.pid == 1
+ assert ti.state == State.RUNNING
session.close()
@@ -290,18 +292,18 @@ def multi_return_code():
with patch.object(StandardTaskRunner, 'return_code') as mock_ret_code:
mock_ret_code.side_effect = multi_return_code
job1.run()
- self.assertEqual(mock_start.call_count, 1)
- self.assertEqual(mock_ret_code.call_count, 2)
+ assert mock_start.call_count == 1
+ assert mock_ret_code.call_count == 2
time_end = time.time()
- self.assertEqual(self.mock_base_job_sleep.call_count, 1)
- self.assertEqual(job1.state, State.SUCCESS)
+ assert self.mock_base_job_sleep.call_count == 1
+ assert job1.state == State.SUCCESS
# Consider we have patched sleep call, it should not be sleeping to
# keep up with the heart rate in other unpatched places
#
# We already make sure patched sleep call is only called once
- self.assertLess(time_end - time_start, job1.heartrate)
+ assert time_end - time_start < job1.heartrate
session.close()
def test_mark_failure_on_failure_callback(self):
@@ -312,13 +314,13 @@ def test_mark_failure_on_failure_callback(self):
data = {'called': False}
def check_failure(context):
- self.assertEqual(context['dag_run'].dag_id, 'test_mark_failure')
+ assert context['dag_run'].dag_id == 'test_mark_failure'
data['called'] = True
def task_function(ti):
print("python_callable run in pid %s", os.getpid())
with create_session() as session:
- self.assertEqual(State.RUNNING, ti.state)
+ assert State.RUNNING == ti.state
ti.log.info("Marking TI as failed 'externally'")
ti.state = State.FAILED
session.merge(ti)
@@ -355,11 +357,9 @@ def task_function(ti):
job1.run()
ti.refresh_from_db()
- self.assertEqual(ti.state, State.FAILED)
- self.assertTrue(data['called'])
- self.assertNotIn(
- 'reached_end_of_sleep', data, 'Task should not have been allowed to run to completion'
- )
+ assert ti.state == State.FAILED
+ assert data['called']
+ assert 'reached_end_of_sleep' not in data, 'Task should not have been allowed to run to completion'
@pytest.mark.quarantined
def test_mark_success_on_success_callback(self):
@@ -370,7 +370,7 @@ def test_mark_success_on_success_callback(self):
data = {'called': False}
def success_callback(context):
- self.assertEqual(context['dag_run'].dag_id, 'test_mark_success')
+ assert context['dag_run'].dag_id == 'test_mark_success'
data['called'] = True
dag = DAG(dag_id='test_mark_success', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
@@ -401,15 +401,15 @@ def success_callback(context):
break
time.sleep(0.1)
ti.refresh_from_db()
- self.assertEqual(State.RUNNING, ti.state)
+ assert State.RUNNING == ti.state
ti.state = State.SUCCESS
session.merge(ti)
session.commit()
job1.heartbeat_callback(session=None)
- self.assertTrue(data['called'])
+ assert data['called']
process.join(timeout=10)
- self.assertFalse(process.is_alive())
+ assert not process.is_alive()
@pytest.fixture()
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index da27016a00763..d0d565d5599df 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -297,11 +297,11 @@ def test_dag_file_processor_only_collect_emails_from_sla_missed_tasks(self, mock
dag_file_processor.manage_slas(dag=dag, session=session)
- self.assertTrue(len(mock_send_email.call_args_list), 1)
+ assert len(mock_send_email.call_args_list) == 1
send_email_to = mock_send_email.call_args_list[0][0][0]
- self.assertIn(email1, send_email_to)
- self.assertNotIn(email2, send_email_to)
+ assert email1 in send_email_to
+ assert email2 not in send_email_to
@mock.patch('airflow.jobs.scheduler_job.Stats.incr')
@mock.patch("airflow.utils.email.send_email")
@@ -566,7 +566,7 @@ def test_scheduler_job_add_new_task(self):
dr = drs[0]
tis = dr.get_task_instances()
- self.assertEqual(len(tis), 1)
+ assert len(tis) == 1
BashOperator(task_id='dummy2', dag=dag, owner='airflow', bash_command='echo test')
SerializedDagModel.write_dag(dag=dag)
@@ -580,7 +580,7 @@ def test_scheduler_job_add_new_task(self):
dr = drs[0]
tis = dr.get_task_instances()
- self.assertEqual(len(tis), 2)
+ assert len(tis) == 2
def test_runs_respected_after_clear(self):
"""
@@ -695,7 +695,7 @@ def test_process_file_should_failure_callback(self):
dag_file_processor.process_file(dag_file, requests)
with open(callback_file.name) as callback_file2:
content = callback_file2.read()
- self.assertEqual("Callback fired", content)
+ assert "Callback fired" == content
os.remove(callback_file.name)
def test_should_mark_dummy_task_as_success(self):
@@ -726,56 +726,50 @@ def test_should_mark_dummy_task_as_success(self):
tis = session.query(TaskInstance).all()
dags = scheduler_job.dagbag.dags.values()
- self.assertEqual(['test_only_dummy_tasks'], [dag.dag_id for dag in dags])
- self.assertEqual(5, len(tis))
- self.assertEqual(
- {
- ('test_task_a', 'success'),
- ('test_task_b', None),
- ('test_task_c', 'success'),
- ('test_task_on_execute', 'scheduled'),
- ('test_task_on_success', 'scheduled'),
- },
- {(ti.task_id, ti.state) for ti in tis},
- )
+ assert ['test_only_dummy_tasks'] == [dag.dag_id for dag in dags]
+ assert 5 == len(tis)
+ assert {
+ ('test_task_a', 'success'),
+ ('test_task_b', None),
+ ('test_task_c', 'success'),
+ ('test_task_on_execute', 'scheduled'),
+ ('test_task_on_success', 'scheduled'),
+ } == {(ti.task_id, ti.state) for ti in tis}
for state, start_date, end_date, duration in [
(ti.state, ti.start_date, ti.end_date, ti.duration) for ti in tis
]:
if state == 'success':
- self.assertIsNotNone(start_date)
- self.assertIsNotNone(end_date)
- self.assertEqual(0.0, duration)
+ assert start_date is not None
+ assert end_date is not None
+ assert 0.0 == duration
else:
- self.assertIsNone(start_date)
- self.assertIsNone(end_date)
- self.assertIsNone(duration)
+ assert start_date is None
+ assert end_date is None
+ assert duration is None
scheduler_job._schedule_dag_run(dr, {}, session)
with create_session() as session:
tis = session.query(TaskInstance).all()
- self.assertEqual(5, len(tis))
- self.assertEqual(
- {
- ('test_task_a', 'success'),
- ('test_task_b', 'success'),
- ('test_task_c', 'success'),
- ('test_task_on_execute', 'scheduled'),
- ('test_task_on_success', 'scheduled'),
- },
- {(ti.task_id, ti.state) for ti in tis},
- )
+ assert 5 == len(tis)
+ assert {
+ ('test_task_a', 'success'),
+ ('test_task_b', 'success'),
+ ('test_task_c', 'success'),
+ ('test_task_on_execute', 'scheduled'),
+ ('test_task_on_success', 'scheduled'),
+ } == {(ti.task_id, ti.state) for ti in tis}
for state, start_date, end_date, duration in [
(ti.state, ti.start_date, ti.end_date, ti.duration) for ti in tis
]:
if state == 'success':
- self.assertIsNotNone(start_date)
- self.assertIsNotNone(end_date)
- self.assertEqual(0.0, duration)
+ assert start_date is not None
+ assert end_date is not None
+ assert 0.0 == duration
else:
- self.assertIsNone(start_date)
- self.assertIsNone(end_date)
- self.assertIsNone(duration)
+ assert start_date is None
+ assert end_date is None
+ assert duration is None
@pytest.mark.usefixtures("disable_load_example")
@@ -810,22 +804,22 @@ def setUpClass(cls):
def test_is_alive(self):
job = SchedulerJob(None, heartrate=10, state=State.RUNNING)
- self.assertTrue(job.is_alive())
+ assert job.is_alive()
job.latest_heartbeat = timezone.utcnow() - datetime.timedelta(seconds=20)
- self.assertTrue(job.is_alive())
+ assert job.is_alive()
job.latest_heartbeat = timezone.utcnow() - datetime.timedelta(seconds=31)
- self.assertFalse(job.is_alive())
+ assert not job.is_alive()
# test because .seconds was used before instead of total_seconds
# internal repr of datetime is (days, seconds)
job.latest_heartbeat = timezone.utcnow() - datetime.timedelta(days=1)
- self.assertFalse(job.is_alive())
+ assert not job.is_alive()
job.state = State.SUCCESS
job.latest_heartbeat = timezone.utcnow() - datetime.timedelta(seconds=10)
- self.assertFalse(job.is_alive(), "Completed jobs even with recent heartbeat should not be alive")
+ assert not job.is_alive(), "Completed jobs even with recent heartbeat should not be alive"
def run_single_scheduler_loop_with_no_dags(self, dags_folder):
"""
@@ -843,6 +837,7 @@ def run_single_scheduler_loop_with_no_dags(self, dags_folder):
scheduler.heartrate = 0
scheduler.run()
+ @pytest.mark.quarantined
def test_no_orphan_process_will_be_left(self):
empty_dir = mkdtemp()
current_process = psutil.Process()
@@ -853,7 +848,7 @@ def test_no_orphan_process_will_be_left(self):
# Remove potential noise created by previous tests.
current_children = set(current_process.children(recursive=True)) - set(old_children)
- self.assertFalse(current_children)
+ assert not current_children
@mock.patch('airflow.jobs.scheduler_job.TaskCallbackRequest')
@mock.patch('airflow.jobs.scheduler_job.Stats.incr')
@@ -888,7 +883,7 @@ def test_process_executor_events(self, mock_stats_incr, mock_task_callback):
scheduler._process_executor_events(session=session)
ti1.refresh_from_db()
- self.assertEqual(ti1.state, State.QUEUED)
+ assert ti1.state == State.QUEUED
mock_task_callback.assert_called_once_with(
full_filepath='/test_path1/',
simple_task_instance=mock.ANY,
@@ -908,7 +903,7 @@ def test_process_executor_events(self, mock_stats_incr, mock_task_callback):
scheduler._process_executor_events(session=session)
ti1.refresh_from_db()
- self.assertEqual(ti1.state, State.SUCCESS)
+ assert ti1.state == State.SUCCESS
scheduler.processor_agent.send_callback_to_execute.assert_not_called()
mock_stats_incr.assert_called_once_with('scheduler.tasks.killed_externally')
@@ -936,7 +931,7 @@ def test_process_executor_events_uses_inmemory_try_number(self):
scheduler._process_executor_events()
# Assert that the even_buffer is empty so the task was popped using right
# task instance key
- self.assertEqual(event_buffer, {})
+ assert event_buffer == {}
def test_execute_task_instances_is_paused_wont_execute(self):
dag_id = 'SchedulerJobTest.test_execute_task_instances_is_paused_wont_execute'
@@ -970,7 +965,7 @@ def test_execute_task_instances_is_paused_wont_execute(self):
scheduler._critical_section_execute_task_instances(session)
session.flush()
ti1.refresh_from_db()
- self.assertEqual(State.SCHEDULED, ti1.state)
+ assert State.SCHEDULED == ti1.state
session.rollback()
def test_execute_task_instances_no_dagrun_task_will_execute(self):
@@ -1008,7 +1003,7 @@ def test_execute_task_instances_no_dagrun_task_will_execute(self):
scheduler._critical_section_execute_task_instances(session)
session.flush()
ti1.refresh_from_db()
- self.assertEqual(State.QUEUED, ti1.state)
+ assert State.QUEUED == ti1.state
session.rollback()
def test_execute_task_instances_backfill_tasks_wont_execute(self):
@@ -1044,12 +1039,12 @@ def test_execute_task_instances_backfill_tasks_wont_execute(self):
session.merge(dr1)
session.flush()
- self.assertTrue(dr1.is_backfill)
+ assert dr1.is_backfill
scheduler._critical_section_execute_task_instances(session)
session.flush()
ti1.refresh_from_db()
- self.assertEqual(State.SCHEDULED, ti1.state)
+ assert State.SCHEDULED == ti1.state
session.rollback()
def test_find_executable_task_instances_backfill_nodagrun(self):
@@ -1096,10 +1091,10 @@ def test_find_executable_task_instances_backfill_nodagrun(self):
res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session)
- self.assertEqual(2, len(res))
+ assert 2 == len(res)
res_keys = map(lambda x: x.key, res)
- self.assertIn(ti_no_dagrun.key, res_keys)
- self.assertIn(ti_with_dagrun.key, res_keys)
+ assert ti_no_dagrun.key in res_keys
+ assert ti_with_dagrun.key in res_keys
session.rollback()
def test_find_executable_task_instances_pool(self):
@@ -1149,13 +1144,13 @@ def test_find_executable_task_instances_pool(self):
res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session)
session.flush()
- self.assertEqual(3, len(res))
+ assert 3 == len(res)
res_keys = []
for ti in res:
res_keys.append(ti.key)
- self.assertIn(tis[0].key, res_keys)
- self.assertIn(tis[1].key, res_keys)
- self.assertIn(tis[3].key, res_keys)
+ assert tis[0].key in res_keys
+ assert tis[1].key in res_keys
+ assert tis[3].key in res_keys
session.rollback()
def test_find_executable_task_instances_in_default_pool(self):
@@ -1199,7 +1194,7 @@ def test_find_executable_task_instances_in_default_pool(self):
# Two tasks w/o pool up for execution and our default pool size is 1
res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session)
- self.assertEqual(1, len(res))
+ assert 1 == len(res)
ti2.state = State.RUNNING
session.merge(ti2)
@@ -1207,7 +1202,7 @@ def test_find_executable_task_instances_in_default_pool(self):
# One task w/o pool up for execution and one task task running
res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session)
- self.assertEqual(0, len(res))
+ assert 0 == len(res)
session.rollback()
session.close()
@@ -1242,7 +1237,7 @@ def test_nonexistent_pool(self):
res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session)
session.flush()
- self.assertEqual(0, len(res))
+ assert 0 == len(res)
session.rollback()
def test_find_executable_task_instances_none(self):
@@ -1269,7 +1264,7 @@ def test_find_executable_task_instances_none(self):
)
session.flush()
- self.assertEqual(0, len(scheduler._executable_task_instances_to_queued(max_tis=32, session=session)))
+ assert 0 == len(scheduler._executable_task_instances_to_queued(max_tis=32, session=session))
session.rollback()
def test_find_executable_task_instances_concurrency(self):
@@ -1319,9 +1314,9 @@ def test_find_executable_task_instances_concurrency(self):
res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session)
- self.assertEqual(1, len(res))
+ assert 1 == len(res)
res_keys = map(lambda x: x.key, res)
- self.assertIn(ti2.key, res_keys)
+ assert ti2.key in res_keys
ti2.state = State.RUNNING
session.merge(ti2)
@@ -1329,7 +1324,7 @@ def test_find_executable_task_instances_concurrency(self):
res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session)
- self.assertEqual(0, len(res))
+ assert 0 == len(res)
session.rollback()
def test_find_executable_task_instances_concurrency_queued(self):
@@ -1370,8 +1365,8 @@ def test_find_executable_task_instances_concurrency_queued(self):
res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session)
- self.assertEqual(1, len(res))
- self.assertEqual(res[0].key, ti3.key)
+ assert 1 == len(res)
+ assert res[0].key == ti3.key
session.rollback()
# TODO: This is a hack, I think I need to just remove the setting and have it on always
@@ -1417,7 +1412,7 @@ def test_find_executable_task_instances_task_concurrency(self): # pylint: disab
res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session)
- self.assertEqual(2, len(res))
+ assert 2 == len(res)
ti1_1.state = State.RUNNING
ti2.state = State.RUNNING
@@ -1430,7 +1425,7 @@ def test_find_executable_task_instances_task_concurrency(self): # pylint: disab
res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session)
- self.assertEqual(1, len(res))
+ assert 1 == len(res)
ti1_2.state = State.RUNNING
ti1_3 = TaskInstance(task1, dr3.execution_date)
@@ -1441,7 +1436,7 @@ def test_find_executable_task_instances_task_concurrency(self): # pylint: disab
res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session)
- self.assertEqual(0, len(res))
+ assert 0 == len(res)
ti1_1.state = State.SCHEDULED
ti1_2.state = State.SCHEDULED
@@ -1453,7 +1448,7 @@ def test_find_executable_task_instances_task_concurrency(self): # pylint: disab
res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session)
- self.assertEqual(2, len(res))
+ assert 2 == len(res)
ti1_1.state = State.RUNNING
ti1_2.state = State.SCHEDULED
@@ -1465,7 +1460,7 @@ def test_find_executable_task_instances_task_concurrency(self): # pylint: disab
res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session)
- self.assertEqual(1, len(res))
+ assert 1 == len(res)
session.rollback()
def test_change_state_for_executable_task_instances_no_tis_with_state(self):
@@ -1510,7 +1505,7 @@ def test_change_state_for_executable_task_instances_no_tis_with_state(self):
session.flush()
res = scheduler._executable_task_instances_to_queued(max_tis=100, session=session)
- self.assertEqual(0, len(res))
+ assert 0 == len(res)
session.rollback()
@@ -1587,10 +1582,8 @@ def test_critical_section_execute_task_instances(self):
session.merge(ti2)
session.flush()
- self.assertEqual(State.RUNNING, dr1.state)
- self.assertEqual(
- 2, DAG.get_num_task_instances(dag_id, dag.task_ids, states=[State.RUNNING], session=session)
- )
+ assert State.RUNNING == dr1.state
+ assert 2 == DAG.get_num_task_instances(dag_id, dag.task_ids, states=[State.RUNNING], session=session)
# create second dag run
dr2 = dag.create_dagrun(
@@ -1609,7 +1602,7 @@ def test_critical_section_execute_task_instances(self):
session.merge(ti4)
session.flush()
- self.assertEqual(State.RUNNING, dr2.state)
+ assert State.RUNNING == dr2.state
res = scheduler._critical_section_execute_task_instances(session)
@@ -1618,16 +1611,13 @@ def test_critical_section_execute_task_instances(self):
ti2.refresh_from_db()
ti3.refresh_from_db()
ti4.refresh_from_db()
- self.assertEqual(
- 3,
- DAG.get_num_task_instances(
- dag_id, dag.task_ids, states=[State.RUNNING, State.QUEUED], session=session
- ),
+ assert 3 == DAG.get_num_task_instances(
+ dag_id, dag.task_ids, states=[State.RUNNING, State.QUEUED], session=session
)
- self.assertEqual(State.RUNNING, ti1.state)
- self.assertEqual(State.RUNNING, ti2.state)
- self.assertCountEqual([State.QUEUED, State.SCHEDULED], [ti3.state, ti4.state])
- self.assertEqual(1, res)
+ assert State.RUNNING == ti1.state
+ assert State.RUNNING == ti2.state
+ assert {State.QUEUED, State.SCHEDULED} == {ti3.state, ti4.state}
+ assert 1 == res
def test_execute_task_instances_limit(self):
dag_id = 'SchedulerJobTest.test_execute_task_instances_limit'
@@ -1674,7 +1664,7 @@ def test_execute_task_instances_limit(self):
session.flush()
scheduler.max_tis_per_query = 2
res = scheduler._critical_section_execute_task_instances(session)
- self.assertEqual(2, res)
+ assert 2 == res
scheduler.max_tis_per_query = 8
with mock.patch.object(
@@ -1682,14 +1672,14 @@ def test_execute_task_instances_limit(self):
) as mock_slots:
mock_slots.return_value = 2
# Check that we don't "overfill" the executor
- self.assertEqual(2, res)
+ assert 2 == res
res = scheduler._critical_section_execute_task_instances(session)
res = scheduler._critical_section_execute_task_instances(session)
- self.assertEqual(4, res)
+ assert 4 == res
for ti in tis:
ti.refresh_from_db()
- self.assertEqual(State.QUEUED, ti.state)
+ assert State.QUEUED == ti.state
def test_change_state_for_tis_without_dagrun(self):
dag1 = DAG(dag_id='test_change_state_for_tis_without_dagrun', start_date=DEFAULT_DATE)
@@ -1754,21 +1744,21 @@ def test_change_state_for_tis_without_dagrun(self):
ti1a = dr1.get_task_instance(task_id='dummy', session=session)
ti1a.refresh_from_db(session=session)
- self.assertEqual(ti1a.state, State.SCHEDULED)
+ assert ti1a.state == State.SCHEDULED
ti1b = dr1.get_task_instance(task_id='dummy_b', session=session)
ti1b.refresh_from_db(session=session)
- self.assertEqual(ti1b.state, State.SUCCESS)
+ assert ti1b.state == State.SUCCESS
ti2 = dr2.get_task_instance(task_id='dummy', session=session)
ti2.refresh_from_db(session=session)
- self.assertEqual(ti2.state, State.SCHEDULED)
+ assert ti2.state == State.SCHEDULED
ti3.refresh_from_db(session=session)
- self.assertEqual(ti3.state, State.NONE)
- self.assertIsNotNone(ti3.start_date)
- self.assertIsNone(ti3.end_date)
- self.assertIsNone(ti3.duration)
+ assert ti3.state == State.NONE
+ assert ti3.start_date is not None
+ assert ti3.end_date is None
+ assert ti3.duration is None
dr1.refresh_from_db(session=session)
dr1.state = State.FAILED
@@ -1784,15 +1774,15 @@ def test_change_state_for_tis_without_dagrun(self):
# Clear the session objects
session.expunge_all()
ti1a.refresh_from_db(session=session)
- self.assertEqual(ti1a.state, State.NONE)
+ assert ti1a.state == State.NONE
# don't touch ti1b
ti1b.refresh_from_db(session=session)
- self.assertEqual(ti1b.state, State.SUCCESS)
+ assert ti1b.state == State.SUCCESS
# don't touch ti2
ti2.refresh_from_db(session=session)
- self.assertEqual(ti2.state, State.SCHEDULED)
+ assert ti2.state == State.SCHEDULED
def test_change_state_for_tasks_failed_to_execute(self):
dag = DAG(dag_id='dag_id', start_date=DEFAULT_DATE)
@@ -1823,7 +1813,7 @@ def test_change_state_for_tasks_failed_to_execute(self):
scheduler_job._change_state_for_tasks_failed_to_execute() # pylint: disable=no-value-for-parameter
ti.refresh_from_db()
- self.assertEqual(State.SCHEDULED, ti.state)
+ assert State.SCHEDULED == ti.state
# Tasks failed to execute with RUNNING state will not be set to SCHEDULED state.
session.query(TaskInstance).delete()
@@ -1836,7 +1826,7 @@ def test_change_state_for_tasks_failed_to_execute(self):
scheduler_job._change_state_for_tasks_failed_to_execute() # pylint: disable=no-value-for-parameter
ti.refresh_from_db()
- self.assertEqual(State.RUNNING, ti.state)
+ assert State.RUNNING == ti.state
def test_adopt_or_reset_orphaned_tasks(self):
session = settings.Session()
@@ -1879,10 +1869,10 @@ def test_adopt_or_reset_orphaned_tasks(self):
scheduler.adopt_or_reset_orphaned_tasks()
ti = dr.get_task_instance(task_id=op1.task_id, session=session)
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
ti2 = dr2.get_task_instance(task_id=op1.task_id, session=session)
- self.assertEqual(ti2.state, State.SCHEDULED, "Tasks run by Backfill Jobs should not be reset")
+ assert ti2.state == State.SCHEDULED, "Tasks run by Backfill Jobs should not be reset"
@parameterized.expand(
[
@@ -1938,12 +1928,12 @@ def test_scheduler_loop_should_change_state_for_tis_without_dagrun(
scheduler._run_scheduler_loop()
ti = dr.get_task_instance(task_id=op1.task_id, session=session)
- self.assertEqual(ti.state, expected_task_state)
- self.assertIsNotNone(ti.start_date)
+ assert ti.state == expected_task_state
+ assert ti.start_date is not None
if expected_task_state in State.finished:
- self.assertIsNotNone(ti.end_date)
- self.assertEqual(ti.start_date, ti.end_date)
- self.assertIsNotNone(ti.duration)
+ assert ti.end_date is not None
+ assert ti.start_date == ti.end_date
+ assert ti.duration is not None
def test_dagrun_timeout_verify_max_active_runs(self):
"""
@@ -2185,7 +2175,7 @@ def test_do_not_schedule_removed_task(self):
state=State.RUNNING,
session=session,
)
- self.assertIsNotNone(dr)
+ assert dr is not None
# Re-create the DAG, but remove the task
dag = DAG(dag_id='test_scheduler_do_not_schedule_removed_task', start_date=DEFAULT_DATE)
@@ -2193,7 +2183,7 @@ def test_do_not_schedule_removed_task(self):
scheduler = SchedulerJob(subdir=os.devnull)
res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session)
- self.assertEqual([], res)
+ assert [] == res
session.rollback()
session.close()
@@ -2250,14 +2240,14 @@ def evaluate_dagrun(
task = dag.get_task(task_id)
ti = TaskInstance(task, ex_date)
ti.refresh_from_db()
- self.assertEqual(ti.state, expected_state)
+ assert ti.state == expected_state
# load dagrun
dr = DagRun.find(dag_id=dag_id, execution_date=ex_date)
dr = dr[0]
dr.dag = dag
- self.assertEqual(dr.state, dagrun_state)
+ assert dr.state == dagrun_state
def test_dagrun_fail(self):
"""
@@ -2313,7 +2303,7 @@ def test_dagrun_root_fail_unfinished(self):
)
self.null_exec.mock_task_fail(dag_id, 'test_dagrun_fail', DEFAULT_DATE)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
dag.run(start_date=dr.execution_date, end_date=dr.execution_date, executor=self.null_exec)
# Mark the successful task as never having run since we want to see if the
@@ -2323,7 +2313,7 @@ def test_dagrun_root_fail_unfinished(self):
ti.state = State.NONE
session.commit()
dr.update_state()
- self.assertEqual(dr.state, State.RUNNING)
+ assert dr.state == State.RUNNING
def test_dagrun_root_after_dagrun_unfinished(self):
"""
@@ -2341,8 +2331,8 @@ def test_dagrun_root_after_dagrun_unfinished(self):
first_run = DagRun.find(dag_id=dag_id, execution_date=DEFAULT_DATE)[0]
ti_ids = [(ti.task_id, ti.state) for ti in first_run.get_task_instances()]
- self.assertEqual(ti_ids, [('current', State.SUCCESS)])
- self.assertIn(first_run.state, [State.SUCCESS, State.RUNNING])
+ assert ti_ids == [('current', State.SUCCESS)]
+ assert first_run.state in [State.SUCCESS, State.RUNNING]
def test_dagrun_deadlock_ignore_depends_on_past_advance_ex_date(self):
"""
@@ -2388,7 +2378,7 @@ def test_scheduler_start_date(self):
dag_id = 'test_start_date_scheduling'
dag = self.dagbag.get_dag(dag_id)
dag.clear()
- self.assertGreater(dag.start_date, datetime.datetime.now(timezone.utc))
+ assert dag.start_date > datetime.datetime.now(timezone.utc)
# Deactivate other dags in this file
other_dag = self.dagbag.get_dag('test_task_start_date_scheduling')
@@ -2399,9 +2389,9 @@ def test_scheduler_start_date(self):
scheduler.run()
# zero tasks ran
- self.assertEqual(len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()), 0)
+ assert len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()) == 0
session.commit()
- self.assertListEqual([], self.null_exec.sorted_tasks)
+ assert [] == self.null_exec.sorted_tasks
# previously, running this backfill would kick off the Scheduler
# because it would take the most recent run and start from there
@@ -2412,22 +2402,19 @@ def test_scheduler_start_date(self):
backfill.run()
# one task ran
- self.assertEqual(len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()), 1)
- self.assertListEqual(
- [
- (TaskInstanceKey(dag.dag_id, 'dummy', DEFAULT_DATE, 1), (State.SUCCESS, None)),
- ],
- bf_exec.sorted_tasks,
- )
+ assert len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()) == 1
+ assert [
+ (TaskInstanceKey(dag.dag_id, 'dummy', DEFAULT_DATE, 1), (State.SUCCESS, None)),
+ ] == bf_exec.sorted_tasks
session.commit()
scheduler = SchedulerJob(dag.fileloc, executor=self.null_exec, num_runs=1)
scheduler.run()
# still one task
- self.assertEqual(len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()), 1)
+ assert len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()) == 1
session.commit()
- self.assertListEqual([], self.null_exec.sorted_tasks)
+ assert [] == self.null_exec.sorted_tasks
@pytest.mark.quarantined
def test_scheduler_task_start_date(self):
@@ -2455,10 +2442,10 @@ def test_scheduler_task_start_date(self):
tiq = session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id)
ti1s = tiq.filter(TaskInstance.task_id == 'dummy1').all()
ti2s = tiq.filter(TaskInstance.task_id == 'dummy2').all()
- self.assertEqual(len(ti1s), 0)
- self.assertEqual(len(ti2s), 2)
+ assert len(ti1s) == 0
+ assert len(ti2s) == 2
for task in ti2s:
- self.assertEqual(task.state, State.SUCCESS)
+ assert task.state == State.SUCCESS
def test_scheduler_multiprocessing(self):
"""
@@ -2479,7 +2466,7 @@ def test_scheduler_multiprocessing(self):
# zero tasks ran
dag_id = 'test_start_date_scheduling'
session = settings.Session()
- self.assertEqual(len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()), 0)
+ assert len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()) == 0
@conf_vars({("core", "mp_start_method"): "spawn"})
def test_scheduler_multiprocessing_with_spawn_method(self):
@@ -2503,7 +2490,7 @@ def test_scheduler_multiprocessing_with_spawn_method(self):
# zero tasks ran
dag_id = 'test_start_date_scheduling'
with create_session() as session:
- self.assertEqual(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).count(), 0)
+ assert session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).count() == 0
def test_scheduler_verify_pool_full(self):
"""
@@ -2553,7 +2540,7 @@ def test_scheduler_verify_pool_full(self):
task_instances_list = scheduler._executable_task_instances_to_queued(max_tis=32, session=session)
- self.assertEqual(len(task_instances_list), 1)
+ assert len(task_instances_list) == 1
def test_scheduler_verify_pool_full_2_slots_per_task(self):
"""
@@ -2604,7 +2591,7 @@ def test_scheduler_verify_pool_full_2_slots_per_task(self):
task_instances_list = scheduler._executable_task_instances_to_queued(max_tis=32, session=session)
# As tasks require 2 slots, only 3 can fit into 6 available
- self.assertEqual(len(task_instances_list), 3)
+ assert len(task_instances_list) == 3
def test_scheduler_verify_priority_and_slots(self):
"""
@@ -2674,28 +2661,28 @@ def test_scheduler_verify_priority_and_slots(self):
task_instances_list = scheduler._executable_task_instances_to_queued(max_tis=32, session=session)
# Only second and third
- self.assertEqual(len(task_instances_list), 2)
+ assert len(task_instances_list) == 2
ti0 = (
session.query(TaskInstance)
.filter(TaskInstance.task_id == 'test_scheduler_verify_priority_and_slots_t0')
.first()
)
- self.assertEqual(ti0.state, State.SCHEDULED)
+ assert ti0.state == State.SCHEDULED
ti1 = (
session.query(TaskInstance)
.filter(TaskInstance.task_id == 'test_scheduler_verify_priority_and_slots_t1')
.first()
)
- self.assertEqual(ti1.state, State.QUEUED)
+ assert ti1.state == State.QUEUED
ti2 = (
session.query(TaskInstance)
.filter(TaskInstance.task_id == 'test_scheduler_verify_priority_and_slots_t2')
.first()
)
- self.assertEqual(ti2.state, State.QUEUED)
+ assert ti2.state == State.QUEUED
def test_verify_integrity_if_dag_not_changed(self):
# CleanUp
@@ -2872,15 +2859,15 @@ def run_with_error(ti, ignore_ti_state=False):
except AirflowException:
pass
- self.assertEqual(ti.try_number, 1)
+ assert ti.try_number == 1
# At this point, scheduler has tried to schedule the task once and
# heartbeated the executor once, which moved the state of the task from
# SCHEDULED to QUEUED and then to SCHEDULED, to fail the task execution
# we need to ignore the TaskInstance state as SCHEDULED is not a valid state to start
# executing task.
run_with_error(ti, ignore_ti_state=True)
- self.assertEqual(ti.state, State.UP_FOR_RETRY)
- self.assertEqual(ti.try_number, 2)
+ assert ti.state == State.UP_FOR_RETRY
+ assert ti.try_number == 2
with create_session() as session:
ti.refresh_from_db(lock_for_update=True, session=session)
@@ -2891,7 +2878,7 @@ def run_with_error(ti, ignore_ti_state=False):
executor.do_update = True
do_schedule() # pylint: disable=no-value-for-parameter
ti.refresh_from_db()
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
@pytest.mark.quarantined
def test_retry_handling_job(self):
@@ -2915,8 +2902,8 @@ def test_retry_handling_job(self):
)
# make sure the counter has increased
- self.assertEqual(ti.try_number, 2)
- self.assertEqual(ti.state, State.UP_FOR_RETRY)
+ assert ti.try_number == 2
+ assert ti.state == State.UP_FOR_RETRY
def test_dag_get_active_runs(self):
"""
@@ -2955,7 +2942,7 @@ def test_dag_get_active_runs(self):
)
# We had better get a dag run
- self.assertIsNotNone(dr)
+ assert dr is not None
execution_date = dr.execution_date
@@ -2966,7 +2953,7 @@ def test_dag_get_active_runs(self):
except Exception: # pylint: disable=broad-except
running_date = 'Except'
- self.assertEqual(execution_date, running_date, 'Running Date must match Execution Date')
+ assert execution_date == running_date, 'Running Date must match Execution Date'
@conf_vars({("core", "dagbag_import_error_tracebacks"): "False"})
def test_add_unparseable_file_before_sched_start_creates_import_error(self):
@@ -2983,10 +2970,10 @@ def test_add_unparseable_file_before_sched_start_creates_import_error(self):
with create_session() as session:
import_errors = session.query(errors.ImportError).all()
- self.assertEqual(len(import_errors), 1)
+ assert len(import_errors) == 1
import_error = import_errors[0]
- self.assertEqual(import_error.filename, unparseable_filename)
- self.assertEqual(import_error.stacktrace, f"invalid syntax ({TEMP_DAG_FILENAME}, line 1)")
+ assert import_error.filename == unparseable_filename
+ assert import_error.stacktrace == f"invalid syntax ({TEMP_DAG_FILENAME}, line 1)"
@conf_vars({("core", "dagbag_import_error_tracebacks"): "False"})
def test_add_unparseable_file_after_sched_start_creates_import_error(self):
@@ -3011,10 +2998,10 @@ def test_add_unparseable_file_after_sched_start_creates_import_error(self):
with create_session() as session:
import_errors = session.query(errors.ImportError).all()
- self.assertEqual(len(import_errors), 1)
+ assert len(import_errors) == 1
import_error = import_errors[0]
- self.assertEqual(import_error.filename, unparseable_filename)
- self.assertEqual(import_error.stacktrace, f"invalid syntax ({TEMP_DAG_FILENAME}, line 1)")
+ assert import_error.filename == unparseable_filename
+ assert import_error.stacktrace == f"invalid syntax ({TEMP_DAG_FILENAME}, line 1)"
def test_no_import_errors_with_parseable_dag(self):
try:
@@ -3030,7 +3017,7 @@ def test_no_import_errors_with_parseable_dag(self):
with create_session() as session:
import_errors = session.query(errors.ImportError).all()
- self.assertEqual(len(import_errors), 0)
+ assert len(import_errors) == 0
@conf_vars({("core", "dagbag_import_error_tracebacks"): "False"})
def test_new_import_error_replaces_old(self):
@@ -3055,10 +3042,10 @@ def test_new_import_error_replaces_old(self):
session = settings.Session()
import_errors = session.query(errors.ImportError).all()
- self.assertEqual(len(import_errors), 1)
+ assert len(import_errors) == 1
import_error = import_errors[0]
- self.assertEqual(import_error.filename, unparseable_filename)
- self.assertEqual(import_error.stacktrace, f"invalid syntax ({TEMP_DAG_FILENAME}, line 2)")
+ assert import_error.filename == unparseable_filename
+ assert import_error.stacktrace == f"invalid syntax ({TEMP_DAG_FILENAME}, line 2)"
def test_remove_error_clears_import_error(self):
try:
@@ -3080,7 +3067,7 @@ def test_remove_error_clears_import_error(self):
session = settings.Session()
import_errors = session.query(errors.ImportError).all()
- self.assertEqual(len(import_errors), 0)
+ assert len(import_errors) == 0
def test_remove_file_clears_import_error(self):
try:
@@ -3100,7 +3087,7 @@ def test_remove_file_clears_import_error(self):
with create_session() as session:
import_errors = session.query(errors.ImportError).all()
- self.assertEqual(len(import_errors), 0)
+ assert len(import_errors) == 0
def test_import_error_tracebacks(self):
dags_folder = mkdtemp()
@@ -3115,9 +3102,9 @@ def test_import_error_tracebacks(self):
with create_session() as session:
import_errors = session.query(errors.ImportError).all()
- self.assertEqual(len(import_errors), 1)
+ assert len(import_errors) == 1
import_error = import_errors[0]
- self.assertEqual(import_error.filename, unparseable_filename)
+ assert import_error.filename == unparseable_filename
expected_stacktrace = (
"Traceback (most recent call last):\n"
' File "{}", line 3, in \n'
@@ -3126,8 +3113,8 @@ def test_import_error_tracebacks(self):
" return airflow_DAG\n"
"NameError: name 'airflow_DAG' is not defined\n"
)
- self.assertEqual(
- import_error.stacktrace, expected_stacktrace.format(unparseable_filename, unparseable_filename)
+ assert import_error.stacktrace == expected_stacktrace.format(
+ unparseable_filename, unparseable_filename
)
@conf_vars({("core", "dagbag_import_error_traceback_depth"): "1"})
@@ -3144,16 +3131,16 @@ def test_import_error_traceback_depth(self):
with create_session() as session:
import_errors = session.query(errors.ImportError).all()
- self.assertEqual(len(import_errors), 1)
+ assert len(import_errors) == 1
import_error = import_errors[0]
- self.assertEqual(import_error.filename, unparseable_filename)
+ assert import_error.filename == unparseable_filename
expected_stacktrace = (
"Traceback (most recent call last):\n"
' File "{}", line 2, in something\n'
" return airflow_DAG\n"
"NameError: name 'airflow_DAG' is not defined\n"
)
- self.assertEqual(import_error.stacktrace, expected_stacktrace.format(unparseable_filename))
+ assert import_error.stacktrace == expected_stacktrace.format(unparseable_filename)
def test_import_error_tracebacks_zip(self):
dags_folder = mkdtemp()
@@ -3169,9 +3156,9 @@ def test_import_error_tracebacks_zip(self):
with create_session() as session:
import_errors = session.query(errors.ImportError).all()
- self.assertEqual(len(import_errors), 1)
+ assert len(import_errors) == 1
import_error = import_errors[0]
- self.assertEqual(import_error.filename, invalid_zip_filename)
+ assert import_error.filename == invalid_zip_filename
expected_stacktrace = (
"Traceback (most recent call last):\n"
' File "{}", line 3, in \n'
@@ -3180,8 +3167,8 @@ def test_import_error_tracebacks_zip(self):
" return airflow_DAG\n"
"NameError: name 'airflow_DAG' is not defined\n"
)
- self.assertEqual(
- import_error.stacktrace, expected_stacktrace.format(invalid_dag_filename, invalid_dag_filename)
+ assert import_error.stacktrace == expected_stacktrace.format(
+ invalid_dag_filename, invalid_dag_filename
)
@conf_vars({("core", "dagbag_import_error_traceback_depth"): "1"})
@@ -3199,16 +3186,16 @@ def test_import_error_tracebacks_zip_depth(self):
with create_session() as session:
import_errors = session.query(errors.ImportError).all()
- self.assertEqual(len(import_errors), 1)
+ assert len(import_errors) == 1
import_error = import_errors[0]
- self.assertEqual(import_error.filename, invalid_zip_filename)
+ assert import_error.filename == invalid_zip_filename
expected_stacktrace = (
"Traceback (most recent call last):\n"
' File "{}", line 2, in something\n'
" return airflow_DAG\n"
"NameError: name 'airflow_DAG' is not defined\n"
)
- self.assertEqual(import_error.stacktrace, expected_stacktrace.format(invalid_dag_filename))
+ assert import_error.stacktrace == expected_stacktrace.format(invalid_dag_filename)
def test_list_py_file_paths(self):
"""
@@ -3231,7 +3218,7 @@ def test_list_py_file_paths(self):
expected_files.add(f'{root}/{file_name}')
for file_path in list_py_file_paths(TEST_DAG_FOLDER, include_examples=False):
detected_files.add(file_path)
- self.assertEqual(detected_files, expected_files)
+ assert detected_files == expected_files
ignored_files = {
'helper.py',
@@ -3245,7 +3232,7 @@ def test_list_py_file_paths(self):
detected_files.clear()
for file_path in list_py_file_paths(TEST_DAG_FOLDER, include_examples=True):
detected_files.add(file_path)
- self.assertEqual(detected_files, expected_files)
+ assert detected_files == expected_files
smart_sensor_dag_folder = airflow.smart_sensor_dags.__path__[0]
for root, _, files in os.walk(smart_sensor_dag_folder):
@@ -3259,13 +3246,13 @@ def test_list_py_file_paths(self):
TEST_DAG_FOLDER, include_examples=True, include_smart_sensor=True
):
detected_files.add(file_path)
- self.assertEqual(detected_files, expected_files)
+ assert detected_files == expected_files
def test_adopt_or_reset_orphaned_tasks_nothing(self):
"""Try with nothing. """
scheduler = SchedulerJob()
session = settings.Session()
- self.assertEqual(0, scheduler.adopt_or_reset_orphaned_tasks(session=session))
+ assert 0 == scheduler.adopt_or_reset_orphaned_tasks(session=session)
def test_adopt_or_reset_orphaned_tasks_external_triggered_dag(self):
dag_id = 'test_reset_orphaned_tasks_external_triggered_dag'
@@ -3291,7 +3278,7 @@ def test_adopt_or_reset_orphaned_tasks_external_triggered_dag(self):
session.commit()
num_reset_tis = scheduler.adopt_or_reset_orphaned_tasks(session=session)
- self.assertEqual(1, num_reset_tis)
+ assert 1 == num_reset_tis
def test_adopt_or_reset_orphaned_tasks_backfill_dag(self):
dag_id = 'test_adopt_or_reset_orphaned_tasks_backfill_dag'
@@ -3317,8 +3304,8 @@ def test_adopt_or_reset_orphaned_tasks_backfill_dag(self):
session.merge(dr1)
session.flush()
- self.assertTrue(dr1.is_backfill)
- self.assertEqual(0, scheduler.adopt_or_reset_orphaned_tasks(session=session))
+ assert dr1.is_backfill
+ assert 0 == scheduler.adopt_or_reset_orphaned_tasks(session=session)
session.rollback()
def test_reset_orphaned_tasks_nonexistent_dagrun(self):
@@ -3340,7 +3327,7 @@ def test_reset_orphaned_tasks_nonexistent_dagrun(self):
session.merge(ti)
session.flush()
- self.assertEqual(0, scheduler.adopt_or_reset_orphaned_tasks(session=session))
+ assert 0 == scheduler.adopt_or_reset_orphaned_tasks(session=session)
session.rollback()
def test_reset_orphaned_tasks_no_orphans(self):
@@ -3368,9 +3355,9 @@ def test_reset_orphaned_tasks_no_orphans(self):
session.merge(tis[0])
session.flush()
- self.assertEqual(0, scheduler.adopt_or_reset_orphaned_tasks(session=session))
+ assert 0 == scheduler.adopt_or_reset_orphaned_tasks(session=session)
tis[0].refresh_from_db()
- self.assertEqual(State.RUNNING, tis[0].state)
+ assert State.RUNNING == tis[0].state
def test_reset_orphaned_tasks_non_running_dagruns(self):
"""Ensure orphaned tasks with non-running dagruns are not reset."""
@@ -3392,14 +3379,14 @@ def test_reset_orphaned_tasks_non_running_dagruns(self):
session=session,
)
tis = dr1.get_task_instances(session=session)
- self.assertEqual(1, len(tis))
+ assert 1 == len(tis)
tis[0].state = State.SCHEDULED
tis[0].queued_by_job_id = scheduler.id
session.merge(dr1)
session.merge(tis[0])
session.flush()
- self.assertEqual(0, scheduler.adopt_or_reset_orphaned_tasks(session=session))
+ assert 0 == scheduler.adopt_or_reset_orphaned_tasks(session=session)
session.rollback()
def test_adopt_or_reset_orphaned_tasks_stale_scheduler_jobs(self):
@@ -3442,12 +3429,12 @@ def test_adopt_or_reset_orphaned_tasks_stale_scheduler_jobs(self):
num_reset_tis = scheduler_job.adopt_or_reset_orphaned_tasks(session=session)
- self.assertEqual(1, num_reset_tis)
+ assert 1 == num_reset_tis
session.refresh(ti1)
- self.assertEqual(None, ti1.state)
+ assert ti1.state is None
session.refresh(ti2)
- self.assertEqual(State.SCHEDULED, ti2.state)
+ assert State.SCHEDULED == ti2.state
session.rollback()
def test_send_sla_callbacks_to_processor_sla_disabled(self):
diff --git a/tests/kubernetes/models/test_secret.py b/tests/kubernetes/models/test_secret.py
index 183a9ef9bd381..c4f0ed834d51b 100644
--- a/tests/kubernetes/models/test_secret.py
+++ b/tests/kubernetes/models/test_secret.py
@@ -29,32 +29,24 @@
class TestSecret(unittest.TestCase):
def test_to_env_secret(self):
secret = Secret('env', 'name', 'secret', 'key')
- self.assertEqual(
- secret.to_env_secret(),
- k8s.V1EnvVar(
- name='NAME',
- value_from=k8s.V1EnvVarSource(
- secret_key_ref=k8s.V1SecretKeySelector(name='secret', key='key')
- ),
- ),
+ assert secret.to_env_secret() == k8s.V1EnvVar(
+ name='NAME',
+ value_from=k8s.V1EnvVarSource(secret_key_ref=k8s.V1SecretKeySelector(name='secret', key='key')),
)
def test_to_env_from_secret(self):
secret = Secret('env', None, 'secret')
- self.assertEqual(
- secret.to_env_from_secret(), k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(name='secret'))
+ assert secret.to_env_from_secret() == k8s.V1EnvFromSource(
+ secret_ref=k8s.V1SecretEnvSource(name='secret')
)
@mock.patch('uuid.uuid4')
def test_to_volume_secret(self, mock_uuid):
mock_uuid.return_value = '0'
secret = Secret('volume', '/etc/foo', 'secret_b')
- self.assertEqual(
- secret.to_volume_secret(),
- (
- k8s.V1Volume(name='secretvol0', secret=k8s.V1SecretVolumeSource(secret_name='secret_b')),
- k8s.V1VolumeMount(mount_path='/etc/foo', name='secretvol0', read_only=True),
- ),
+ assert secret.to_volume_secret() == (
+ k8s.V1Volume(name='secretvol0', secret=k8s.V1SecretVolumeSource(secret_name='secret_b')),
+ k8s.V1VolumeMount(mount_path='/etc/foo', name='secretvol0', read_only=True),
)
@mock.patch('uuid.uuid4')
@@ -62,14 +54,11 @@ def test_only_mount_sub_secret(self, mock_uuid):
mock_uuid.return_value = '0'
items = [k8s.V1KeyToPath(key="my-username", path="/extra/path")]
secret = Secret('volume', '/etc/foo', 'secret_b', items=items)
- self.assertEqual(
- secret.to_volume_secret(),
- (
- k8s.V1Volume(
- name='secretvol0', secret=k8s.V1SecretVolumeSource(secret_name='secret_b', items=items)
- ),
- k8s.V1VolumeMount(mount_path='/etc/foo', name='secretvol0', read_only=True),
+ assert secret.to_volume_secret() == (
+ k8s.V1Volume(
+ name='secretvol0', secret=k8s.V1SecretVolumeSource(secret_name='secret_b', items=items)
),
+ k8s.V1VolumeMount(mount_path='/etc/foo', name='secretvol0', read_only=True),
)
@mock.patch('uuid.uuid4')
@@ -89,60 +78,57 @@ def test_attach_to_pod(self, mock_uuid):
k8s_client = ApiClient()
pod = append_to_pod(pod, secrets)
result = k8s_client.sanitize_for_serialization(pod)
- self.assertEqual(
- result,
- {
- 'apiVersion': 'v1',
- 'kind': 'Pod',
- 'metadata': {
- 'labels': {'app': 'myapp'},
- 'name': 'myapp-pod-cf4a56d281014217b0272af6216feb48',
- 'namespace': 'default',
- },
- 'spec': {
- 'containers': [
- {
- 'command': ['sh', '-c', 'echo Hello Kubernetes!'],
- 'env': [
- {'name': 'ENVIRONMENT', 'value': 'prod'},
- {'name': 'LOG_LEVEL', 'value': 'warning'},
- {
- 'name': 'TARGET',
- 'valueFrom': {'secretKeyRef': {'key': 'source_b', 'name': 'secret_b'}},
- },
- ],
- 'envFrom': [
- {'configMapRef': {'name': 'configmap_a'}},
- {'secretRef': {'name': 'secret_a'}},
- ],
- 'image': 'busybox',
- 'name': 'base',
- 'ports': [{'containerPort': 1234, 'name': 'foo'}],
- 'resources': {'limits': {'memory': '200Mi'}, 'requests': {'memory': '100Mi'}},
- 'volumeMounts': [
- {'mountPath': '/airflow/xcom', 'name': 'xcom'},
- {
- 'mountPath': '/etc/foo',
- 'name': 'secretvol' + str(static_uuid),
- 'readOnly': True,
- },
- ],
- },
- {
- 'command': ['sh', '-c', 'trap "exit 0" INT; while true; do sleep 30; done;'],
- 'image': 'alpine',
- 'name': 'airflow-xcom-sidecar',
- 'resources': {'requests': {'cpu': '1m'}},
- 'volumeMounts': [{'mountPath': '/airflow/xcom', 'name': 'xcom'}],
- },
- ],
- 'hostNetwork': True,
- 'imagePullSecrets': [{'name': 'pull_secret_a'}, {'name': 'pull_secret_b'}],
- 'securityContext': {'fsGroup': 2000, 'runAsUser': 1000},
- 'volumes': [
- {'emptyDir': {}, 'name': 'xcom'},
- {'name': 'secretvol' + str(static_uuid), 'secret': {'secretName': 'secret_b'}},
- ],
- },
+ assert result == {
+ 'apiVersion': 'v1',
+ 'kind': 'Pod',
+ 'metadata': {
+ 'labels': {'app': 'myapp'},
+ 'name': 'myapp-pod-cf4a56d281014217b0272af6216feb48',
+ 'namespace': 'default',
},
- )
+ 'spec': {
+ 'containers': [
+ {
+ 'command': ['sh', '-c', 'echo Hello Kubernetes!'],
+ 'env': [
+ {'name': 'ENVIRONMENT', 'value': 'prod'},
+ {'name': 'LOG_LEVEL', 'value': 'warning'},
+ {
+ 'name': 'TARGET',
+ 'valueFrom': {'secretKeyRef': {'key': 'source_b', 'name': 'secret_b'}},
+ },
+ ],
+ 'envFrom': [
+ {'configMapRef': {'name': 'configmap_a'}},
+ {'secretRef': {'name': 'secret_a'}},
+ ],
+ 'image': 'busybox',
+ 'name': 'base',
+ 'ports': [{'containerPort': 1234, 'name': 'foo'}],
+ 'resources': {'limits': {'memory': '200Mi'}, 'requests': {'memory': '100Mi'}},
+ 'volumeMounts': [
+ {'mountPath': '/airflow/xcom', 'name': 'xcom'},
+ {
+ 'mountPath': '/etc/foo',
+ 'name': 'secretvol' + str(static_uuid),
+ 'readOnly': True,
+ },
+ ],
+ },
+ {
+ 'command': ['sh', '-c', 'trap "exit 0" INT; while true; do sleep 30; done;'],
+ 'image': 'alpine',
+ 'name': 'airflow-xcom-sidecar',
+ 'resources': {'requests': {'cpu': '1m'}},
+ 'volumeMounts': [{'mountPath': '/airflow/xcom', 'name': 'xcom'}],
+ },
+ ],
+ 'hostNetwork': True,
+ 'imagePullSecrets': [{'name': 'pull_secret_a'}, {'name': 'pull_secret_b'}],
+ 'securityContext': {'fsGroup': 2000, 'runAsUser': 1000},
+ 'volumes': [
+ {'emptyDir': {}, 'name': 'xcom'},
+ {'name': 'secretvol' + str(static_uuid), 'secret': {'secretName': 'secret_b'}},
+ ],
+ },
+ }
diff --git a/tests/kubernetes/test_client.py b/tests/kubernetes/test_client.py
index e0bb372b36b3f..bf5dcfc3d0a2d 100644
--- a/tests/kubernetes/test_client.py
+++ b/tests/kubernetes/test_client.py
@@ -54,8 +54,8 @@ def test_enable_tcp_keepalive(self):
_enable_tcp_keepalive()
- self.assertEqual(HTTPConnection.default_socket_options, expected_http_connection_options)
- self.assertEqual(HTTPSConnection.default_socket_options, expected_https_connection_options)
+ assert HTTPConnection.default_socket_options == expected_http_connection_options
+ assert HTTPSConnection.default_socket_options == expected_https_connection_options
def test_disable_verify_ssl(self):
configuration = Configuration()
diff --git a/tests/kubernetes/test_pod_generator.py b/tests/kubernetes/test_pod_generator.py
index 156a0e20e1ffc..0c2efe537292d 100644
--- a/tests/kubernetes/test_pod_generator.py
+++ b/tests/kubernetes/test_pod_generator.py
@@ -19,6 +19,7 @@
import uuid
from unittest import mock
+import pytest
from dateutil import parser
from kubernetes.client import ApiClient, models as k8s
@@ -184,7 +185,7 @@ def test_gen_pod_extract_xcom(self, mock_uuid):
result_dict = self.k8s_client.sanitize_for_serialization(result)
expected_dict = self.k8s_client.sanitize_for_serialization(self.expected)
- self.assertEqual(result_dict, expected_dict)
+ assert result_dict == expected_dict
def test_from_obj(self):
result = PodGenerator.from_obj(
@@ -216,28 +217,23 @@ def test_from_obj(self):
)
result = self.k8s_client.sanitize_for_serialization(result)
- self.assertEqual(
- {
- 'apiVersion': 'v1',
- 'kind': 'Pod',
- 'metadata': {
- 'name': 'foo',
- 'annotations': {'test': 'annotation'},
- },
- 'spec': {
- 'containers': [
- {
- 'name': 'base',
- 'volumeMounts': [
- {'mountPath': '/foo/', 'name': 'example-kubernetes-test-volume'}
- ],
- }
- ],
- 'volumes': [{'hostPath': {'path': '/tmp/'}, 'name': 'example-kubernetes-test-volume'}],
- },
+ assert {
+ 'apiVersion': 'v1',
+ 'kind': 'Pod',
+ 'metadata': {
+ 'name': 'foo',
+ 'annotations': {'test': 'annotation'},
},
- result,
- )
+ 'spec': {
+ 'containers': [
+ {
+ 'name': 'base',
+ 'volumeMounts': [{'mountPath': '/foo/', 'name': 'example-kubernetes-test-volume'}],
+ }
+ ],
+ 'volumes': [{'hostPath': {'path': '/tmp/'}, 'name': 'example-kubernetes-test-volume'}],
+ },
+ } == result
result = PodGenerator.from_obj(
{
"KubernetesExecutor": {
@@ -293,40 +289,33 @@ def test_from_obj(self):
'volumes': [{'hostPath': '/tmp/', 'name': 'example-kubernetes-test-volume'}],
},
}
- self.assertEqual(
- result_from_pod,
- expected_from_pod,
- "There was a discrepency between KubernetesExecutor and pod_override",
- )
+ assert (
+ result_from_pod == expected_from_pod
+ ), "There was a discrepency between KubernetesExecutor and pod_override"
- self.assertEqual(
- {
- 'apiVersion': 'v1',
- 'kind': 'Pod',
- 'metadata': {
- 'annotations': {'test': 'annotation'},
- },
- 'spec': {
- 'containers': [
- {
- 'args': [],
- 'command': [],
- 'env': [],
- 'envFrom': [],
- 'name': 'base',
- 'ports': [],
- 'volumeMounts': [
- {'mountPath': '/foo/', 'name': 'example-kubernetes-test-volume'}
- ],
- }
- ],
- 'hostNetwork': False,
- 'imagePullSecrets': [],
- 'volumes': [{'hostPath': {'path': '/tmp/'}, 'name': 'example-kubernetes-test-volume'}],
- },
+ assert {
+ 'apiVersion': 'v1',
+ 'kind': 'Pod',
+ 'metadata': {
+ 'annotations': {'test': 'annotation'},
},
- result,
- )
+ 'spec': {
+ 'containers': [
+ {
+ 'args': [],
+ 'command': [],
+ 'env': [],
+ 'envFrom': [],
+ 'name': 'base',
+ 'ports': [],
+ 'volumeMounts': [{'mountPath': '/foo/', 'name': 'example-kubernetes-test-volume'}],
+ }
+ ],
+ 'hostNetwork': False,
+ 'imagePullSecrets': [],
+ 'volumes': [{'hostPath': {'path': '/tmp/'}, 'name': 'example-kubernetes-test-volume'}],
+ },
+ } == result
@mock.patch('uuid.uuid4')
def test_reconcile_pods_empty_mutator_pod(self, mock_uuid):
@@ -341,11 +330,11 @@ def test_reconcile_pods_empty_mutator_pod(self, mock_uuid):
base_pod.metadata.name = name
result = PodGenerator.reconcile_pods(base_pod, mutator_pod)
- self.assertEqual(base_pod, result)
+ assert base_pod == result
mutator_pod = k8s.V1Pod()
result = PodGenerator.reconcile_pods(base_pod, mutator_pod)
- self.assertEqual(base_pod, result)
+ assert base_pod == result
@mock.patch('uuid.uuid4')
def test_reconcile_pods(self, mock_uuid):
@@ -401,7 +390,7 @@ def test_reconcile_pods(self, mock_uuid):
result_dict = self.k8s_client.sanitize_for_serialization(result)
expected_dict = self.k8s_client.sanitize_for_serialization(expected)
- self.assertEqual(result_dict, expected_dict)
+ assert result_dict == expected_dict
@mock.patch('uuid.uuid4')
def test_construct_pod(self, mock_uuid):
@@ -449,7 +438,7 @@ def test_construct_pod(self, mock_uuid):
result_dict = self.k8s_client.sanitize_for_serialization(result)
expected_dict = self.k8s_client.sanitize_for_serialization(self.expected)
- self.assertEqual(expected_dict, result_dict)
+ assert expected_dict == result_dict
@mock.patch('uuid.uuid4')
def test_construct_pod_empty_executor_config(self, mock_uuid):
@@ -483,27 +472,27 @@ def test_construct_pod_empty_executor_config(self, mock_uuid):
k8s.V1EnvVar(name="AIRFLOW_IS_K8S_EXECUTOR_POD", value='True')
)
worker_config_result = self.k8s_client.sanitize_for_serialization(worker_config)
- self.assertEqual(worker_config_result, sanitized_result)
+ assert worker_config_result == sanitized_result
def test_merge_objects_empty(self):
annotations = {'foo1': 'bar1'}
base_obj = k8s.V1ObjectMeta(annotations=annotations)
client_obj = None
res = merge_objects(base_obj, client_obj)
- self.assertEqual(base_obj, res)
+ assert base_obj == res
client_obj = k8s.V1ObjectMeta()
res = merge_objects(base_obj, client_obj)
- self.assertEqual(base_obj, res)
+ assert base_obj == res
client_obj = k8s.V1ObjectMeta(annotations=annotations)
base_obj = None
res = merge_objects(base_obj, client_obj)
- self.assertEqual(client_obj, res)
+ assert client_obj == res
base_obj = k8s.V1ObjectMeta()
res = merge_objects(base_obj, client_obj)
- self.assertEqual(client_obj, res)
+ assert client_obj == res
def test_merge_objects(self):
base_annotations = {'foo1': 'bar1'}
@@ -513,7 +502,7 @@ def test_merge_objects(self):
client_obj = k8s.V1ObjectMeta(annotations=client_annotations)
res = merge_objects(base_obj, client_obj)
client_obj.labels = base_labels
- self.assertEqual(client_obj, res)
+ assert client_obj == res
def test_extend_object_field_empty(self):
ports = [k8s.V1ContainerPort(container_port=1, name='port')]
@@ -521,21 +510,21 @@ def test_extend_object_field_empty(self):
client_obj = k8s.V1Container(name='client_container')
res = extend_object_field(base_obj, client_obj, 'ports')
client_obj.ports = ports
- self.assertEqual(client_obj, res)
+ assert client_obj == res
base_obj = k8s.V1Container(name='base_container')
client_obj = k8s.V1Container(name='base_container', ports=ports)
res = extend_object_field(base_obj, client_obj, 'ports')
- self.assertEqual(client_obj, res)
+ assert client_obj == res
def test_extend_object_field_not_list(self):
base_obj = k8s.V1Container(name='base_container', image='image')
client_obj = k8s.V1Container(name='client_container')
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
extend_object_field(base_obj, client_obj, 'image')
base_obj = k8s.V1Container(name='base_container')
client_obj = k8s.V1Container(name='client_container', image='image')
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
extend_object_field(base_obj, client_obj, 'image')
def test_extend_object_field(self):
@@ -545,21 +534,21 @@ def test_extend_object_field(self):
client_obj = k8s.V1Container(name='client_container', ports=client_ports)
res = extend_object_field(base_obj, client_obj, 'ports')
client_obj.ports = base_ports + client_ports
- self.assertEqual(client_obj, res)
+ assert client_obj == res
def test_reconcile_containers_empty(self):
base_objs = [k8s.V1Container(name='base_container')]
client_objs = []
res = PodGenerator.reconcile_containers(base_objs, client_objs)
- self.assertEqual(base_objs, res)
+ assert base_objs == res
client_objs = [k8s.V1Container(name='client_container')]
base_objs = []
res = PodGenerator.reconcile_containers(base_objs, client_objs)
- self.assertEqual(client_objs, res)
+ assert client_objs == res
res = PodGenerator.reconcile_containers([], [])
- self.assertEqual(res, [])
+ assert res == []
def test_reconcile_containers(self):
base_ports = [k8s.V1ContainerPort(container_port=1, name='base_port')]
@@ -574,7 +563,7 @@ def test_reconcile_containers(self):
]
res = PodGenerator.reconcile_containers(base_objs, client_objs)
client_objs[0].ports = base_ports + client_ports
- self.assertEqual(client_objs, res)
+ assert client_objs == res
base_ports = [k8s.V1ContainerPort(container_port=1, name='base_port')]
base_objs = [
@@ -589,18 +578,18 @@ def test_reconcile_containers(self):
res = PodGenerator.reconcile_containers(base_objs, client_objs)
client_objs[0].ports = base_ports + client_ports
client_objs[1].image = 'base_image'
- self.assertEqual(client_objs, res)
+ assert client_objs == res
def test_reconcile_specs_empty(self):
base_spec = k8s.V1PodSpec(containers=[])
client_spec = None
res = PodGenerator.reconcile_specs(base_spec, client_spec)
- self.assertEqual(base_spec, res)
+ assert base_spec == res
base_spec = None
client_spec = k8s.V1PodSpec(containers=[])
res = PodGenerator.reconcile_specs(base_spec, client_spec)
- self.assertEqual(client_spec, res)
+ assert client_spec == res
def test_reconcile_specs(self):
base_objs = [k8s.V1Container(name='base_container1', image='base_image')]
@@ -610,13 +599,13 @@ def test_reconcile_specs(self):
res = PodGenerator.reconcile_specs(base_spec, client_spec)
client_spec.containers = [k8s.V1Container(name='client_container1', image='base_image')]
client_spec.active_deadline_seconds = 100
- self.assertEqual(client_spec, res)
+ assert client_spec == res
def test_deserialize_model_file(self):
path = sys.path[0] + '/tests/kubernetes/pod.yaml'
result = PodGenerator.deserialize_model_file(path)
sanitized_res = self.k8s_client.sanitize_for_serialization(result)
- self.assertEqual(sanitized_res, self.deserialize_result)
+ assert sanitized_res == self.deserialize_result
def test_deserialize_model_string(self):
fixture = """
@@ -639,12 +628,12 @@ def test_deserialize_model_string(self):
"""
result = PodGenerator.deserialize_model_file(fixture)
sanitized_res = self.k8s_client.sanitize_for_serialization(result)
- self.assertEqual(sanitized_res, self.deserialize_result)
+ assert sanitized_res == self.deserialize_result
def test_validate_pod_generator(self):
- with self.assertRaises(AirflowConfigException):
+ with pytest.raises(AirflowConfigException):
PodGenerator(pod=k8s.V1Pod(), pod_template_file='k')
- with self.assertRaises(AirflowConfigException):
+ with pytest.raises(AirflowConfigException):
PodGenerator()
PodGenerator(pod_template_file='tests/kubernetes/pod.yaml')
PodGenerator(pod=k8s.V1Pod())
diff --git a/tests/kubernetes/test_pod_launcher.py b/tests/kubernetes/test_pod_launcher.py
index 2b55a7770bdcf..9e7cc82651d9a 100644
--- a/tests/kubernetes/test_pod_launcher.py
+++ b/tests/kubernetes/test_pod_launcher.py
@@ -17,6 +17,7 @@
import unittest
from unittest import mock
+import pytest
from requests.exceptions import BaseHTTPError
from airflow.exceptions import AirflowException
@@ -32,7 +33,7 @@ def test_read_pod_logs_successfully_returns_logs(self):
mock.sentinel.metadata = mock.MagicMock()
self.mock_kube_client.read_namespaced_pod_log.return_value = mock.sentinel.logs
logs = self.pod_launcher.read_pod_logs(mock.sentinel)
- self.assertEqual(mock.sentinel.logs, logs)
+ assert mock.sentinel.logs == logs
def test_read_pod_logs_retries_successfully(self):
mock.sentinel.metadata = mock.MagicMock()
@@ -41,7 +42,7 @@ def test_read_pod_logs_retries_successfully(self):
mock.sentinel.logs,
]
logs = self.pod_launcher.read_pod_logs(mock.sentinel)
- self.assertEqual(mock.sentinel.logs, logs)
+ assert mock.sentinel.logs == logs
self.mock_kube_client.read_namespaced_pod_log.assert_has_calls(
[
mock.call(
@@ -70,13 +71,14 @@ def test_read_pod_logs_retries_fails(self):
BaseHTTPError('Boom'),
BaseHTTPError('Boom'),
]
- self.assertRaises(AirflowException, self.pod_launcher.read_pod_logs, mock.sentinel)
+ with pytest.raises(AirflowException):
+ self.pod_launcher.read_pod_logs(mock.sentinel)
def test_read_pod_logs_successfully_with_tail_lines(self):
mock.sentinel.metadata = mock.MagicMock()
self.mock_kube_client.read_namespaced_pod_log.side_effect = [mock.sentinel.logs]
logs = self.pod_launcher.read_pod_logs(mock.sentinel, tail_lines=100)
- self.assertEqual(mock.sentinel.logs, logs)
+ assert mock.sentinel.logs == logs
self.mock_kube_client.read_namespaced_pod_log.assert_has_calls(
[
mock.call(
@@ -95,7 +97,7 @@ def test_read_pod_logs_successfully_with_since_seconds(self):
mock.sentinel.metadata = mock.MagicMock()
self.mock_kube_client.read_namespaced_pod_log.side_effect = [mock.sentinel.logs]
logs = self.pod_launcher.read_pod_logs(mock.sentinel, since_seconds=2)
- self.assertEqual(mock.sentinel.logs, logs)
+ assert mock.sentinel.logs == logs
self.mock_kube_client.read_namespaced_pod_log.assert_has_calls(
[
mock.call(
@@ -114,7 +116,7 @@ def test_read_pod_events_successfully_returns_events(self):
mock.sentinel.metadata = mock.MagicMock()
self.mock_kube_client.list_namespaced_event.return_value = mock.sentinel.events
events = self.pod_launcher.read_pod_events(mock.sentinel)
- self.assertEqual(mock.sentinel.events, events)
+ assert mock.sentinel.events == events
def test_read_pod_events_retries_successfully(self):
mock.sentinel.metadata = mock.MagicMock()
@@ -123,7 +125,7 @@ def test_read_pod_events_retries_successfully(self):
mock.sentinel.events,
]
events = self.pod_launcher.read_pod_events(mock.sentinel)
- self.assertEqual(mock.sentinel.events, events)
+ assert mock.sentinel.events == events
self.mock_kube_client.list_namespaced_event.assert_has_calls(
[
mock.call(
@@ -144,13 +146,14 @@ def test_read_pod_events_retries_fails(self):
BaseHTTPError('Boom'),
BaseHTTPError('Boom'),
]
- self.assertRaises(AirflowException, self.pod_launcher.read_pod_events, mock.sentinel)
+ with pytest.raises(AirflowException):
+ self.pod_launcher.read_pod_events(mock.sentinel)
def test_read_pod_returns_logs(self):
mock.sentinel.metadata = mock.MagicMock()
self.mock_kube_client.read_namespaced_pod.return_value = mock.sentinel.pod_info
pod_info = self.pod_launcher.read_pod(mock.sentinel)
- self.assertEqual(mock.sentinel.pod_info, pod_info)
+ assert mock.sentinel.pod_info == pod_info
def test_read_pod_retries_successfully(self):
mock.sentinel.metadata = mock.MagicMock()
@@ -159,7 +162,7 @@ def test_read_pod_retries_successfully(self):
mock.sentinel.pod_info,
]
pod_info = self.pod_launcher.read_pod(mock.sentinel)
- self.assertEqual(mock.sentinel.pod_info, pod_info)
+ assert mock.sentinel.pod_info == pod_info
self.mock_kube_client.read_namespaced_pod.assert_has_calls(
[
mock.call(mock.sentinel.metadata.name, mock.sentinel.metadata.namespace),
@@ -174,17 +177,16 @@ def test_read_pod_retries_fails(self):
BaseHTTPError('Boom'),
BaseHTTPError('Boom'),
]
- self.assertRaises(AirflowException, self.pod_launcher.read_pod, mock.sentinel)
+ with pytest.raises(AirflowException):
+ self.pod_launcher.read_pod(mock.sentinel)
def test_parse_log_line(self):
timestamp, message = self.pod_launcher.parse_log_line(
'2020-10-08T14:16:17.793417674Z Valid message\n'
)
- self.assertEqual(timestamp, '2020-10-08T14:16:17.793417674Z')
- self.assertEqual(message, 'Valid message')
+ assert timestamp == '2020-10-08T14:16:17.793417674Z'
+ assert message == 'Valid message'
- self.assertRaises(
- Exception,
- self.pod_launcher.parse_log_line('2020-10-08T14:16:17.793417674ZInvalid message\n'),
- )
+ with pytest.raises(Exception):
+ self.pod_launcher.parse_log_line('2020-10-08T14:16:17.793417674ZInvalidmessage\n')
diff --git a/tests/kubernetes/test_refresh_config.py b/tests/kubernetes/test_refresh_config.py
index ca3740efd2257..a0753e2e4b209 100644
--- a/tests/kubernetes/test_refresh_config.py
+++ b/tests/kubernetes/test_refresh_config.py
@@ -17,6 +17,7 @@
from unittest import TestCase
+import pytest
from pendulum.parsing import ParserError
from airflow.kubernetes.refresh_config import _parse_timestamp
@@ -25,12 +26,12 @@
class TestRefreshKubeConfigLoader(TestCase):
def test_parse_timestamp_should_convert_z_timezone_to_unix_timestamp(self):
ts = _parse_timestamp("2020-01-13T13:42:20Z")
- self.assertEqual(1578922940, ts)
+ assert 1578922940 == ts
def test_parse_timestamp_should_convert_regular_timezone_to_unix_timestamp(self):
ts = _parse_timestamp("2020-01-13T13:42:20+0600")
- self.assertEqual(1578922940, ts)
+ assert 1578922940 == ts
def test_parse_timestamp_should_throw_exception(self):
- with self.assertRaises(ParserError):
+ with pytest.raises(ParserError):
_parse_timestamp("foobar")
diff --git a/tests/lineage/test_lineage.py b/tests/lineage/test_lineage.py
index cc28fc4b3ea18..350a8be29a4d5 100644
--- a/tests/lineage/test_lineage.py
+++ b/tests/lineage/test_lineage.py
@@ -66,29 +66,29 @@ def test_lineage(self):
# prepare with manual inlets and outlets
op1.pre_execute(ctx1)
- self.assertEqual(len(op1.inlets), 1)
- self.assertEqual(op1.inlets[0].url, f1s.format(DEFAULT_DATE))
+ assert len(op1.inlets) == 1
+ assert op1.inlets[0].url == f1s.format(DEFAULT_DATE)
- self.assertEqual(len(op1.outlets), 1)
- self.assertEqual(op1.outlets[0].url, f2s.format(DEFAULT_DATE))
+ assert len(op1.outlets) == 1
+ assert op1.outlets[0].url == f2s.format(DEFAULT_DATE)
# post process with no backend
op1.post_execute(ctx1)
op2.pre_execute(ctx2)
- self.assertEqual(len(op2.inlets), 0)
+ assert len(op2.inlets) == 0
op2.post_execute(ctx2)
op3.pre_execute(ctx3)
- self.assertEqual(len(op3.inlets), 1)
- self.assertEqual(op3.inlets[0].url, f2s.format(DEFAULT_DATE))
- self.assertEqual(op3.outlets[0], file3)
+ assert len(op3.inlets) == 1
+ assert op3.inlets[0].url == f2s.format(DEFAULT_DATE)
+ assert op3.outlets[0] == file3
op3.post_execute(ctx3)
# skip 4
op5.pre_execute(ctx5)
- self.assertEqual(len(op5.inlets), 2)
+ assert len(op5.inlets) == 2
op5.post_execute(ctx5)
def test_lineage_render(self):
@@ -109,5 +109,5 @@ def test_lineage_render(self):
ctx1 = {"ti": TI(task=op1, execution_date=DEFAULT_DATE), "execution_date": DEFAULT_DATE}
op1.pre_execute(ctx1)
- self.assertEqual(op1.inlets[0].url, f1s.format(DEFAULT_DATE))
- self.assertEqual(op1.outlets[0].url, f1s.format(DEFAULT_DATE))
+ assert op1.inlets[0].url == f1s.format(DEFAULT_DATE)
+ assert op1.outlets[0].url == f1s.format(DEFAULT_DATE)
diff --git a/tests/macros/test_hive.py b/tests/macros/test_hive.py
index 64310fc9ef9f5..c72c7f7d74239 100644
--- a/tests/macros/test_hive.py
+++ b/tests/macros/test_hive.py
@@ -32,10 +32,10 @@ def test_closest_ds_partition(self):
target_dt = datetime.strptime('2017-04-27', '%Y-%m-%d')
date_list = [date1, date2, date3, date4, date5]
- self.assertEqual("2017-04-26", str(hive._closest_date(target_dt, date_list, True)))
- self.assertEqual("2017-04-28", str(hive._closest_date(target_dt, date_list, False)))
+ assert "2017-04-26" == str(hive._closest_date(target_dt, date_list, True))
+ assert "2017-04-28" == str(hive._closest_date(target_dt, date_list, False))
# when before is not set, the closest date should be returned
- self.assertEqual("2017-04-26", str(hive._closest_date(target_dt, [date1, date2, date3, date5], None)))
- self.assertEqual("2017-04-28", str(hive._closest_date(target_dt, [date1, date2, date4, date5])))
- self.assertEqual("2017-04-26", str(hive._closest_date(target_dt, date_list)))
+ assert "2017-04-26" == str(hive._closest_date(target_dt, [date1, date2, date3, date5], None))
+ assert "2017-04-28" == str(hive._closest_date(target_dt, [date1, date2, date4, date5]))
+ assert "2017-04-26" == str(hive._closest_date(target_dt, date_list))
diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py
index 02541b76ecb0f..95607a668f274 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -133,7 +133,7 @@ def test_render_template(self, content, context, expected_output):
task = DummyOperator(task_id="op1")
result = task.render_template(content, context)
- self.assertEqual(result, expected_output)
+ assert result == expected_output
def test_render_template_fields(self):
"""Verify if operator attributes are correctly templated."""
@@ -141,13 +141,13 @@ def test_render_template_fields(self):
task = MockOperator(task_id="op1", arg1="{{ foo }}", arg2="{{ bar }}")
# Assert nothing is templated yet
- self.assertEqual(task.arg1, "{{ foo }}")
- self.assertEqual(task.arg2, "{{ bar }}")
+ assert task.arg1 == "{{ foo }}"
+ assert task.arg2 == "{{ bar }}"
# Trigger templating and verify if attributes are templated correctly
task.render_template_fields(context={"foo": "footemplated", "bar": "bartemplated"})
- self.assertEqual(task.arg1, "footemplated")
- self.assertEqual(task.arg2, "bartemplated")
+ assert task.arg1 == "footemplated"
+ assert task.arg2 == "bartemplated"
@parameterized.expand(
[
@@ -167,7 +167,7 @@ def test_render_template_fields_with_dag_settings(self, dag_kwargs, content, con
task = DummyOperator(task_id="op1")
result = task.render_template(content, context)
- self.assertEqual(result, expected_output)
+ assert result == expected_output
@parameterized.expand([(object(),), (uuid.uuid4(),)])
def test_render_template_fields_no_change(self, content):
@@ -176,14 +176,14 @@ def test_render_template_fields_no_change(self, content):
task = DummyOperator(task_id="op1")
result = task.render_template(content, {"foo": "bar"})
- self.assertEqual(content, result)
+ assert content == result
def test_render_template_field_undefined_default(self):
"""Test render_template with template_undefined unchanged."""
with DAG("test-dag", start_date=DEFAULT_DATE):
task = DummyOperator(task_id="op1")
- with self.assertRaises(jinja2.UndefinedError):
+ with pytest.raises(jinja2.UndefinedError):
task.render_template("{{ foo }}", {})
def test_render_template_field_undefined_strict(self):
@@ -191,7 +191,7 @@ def test_render_template_field_undefined_strict(self):
with DAG("test-dag", start_date=DEFAULT_DATE, template_undefined=jinja2.StrictUndefined):
task = DummyOperator(task_id="op1")
- with self.assertRaises(jinja2.UndefinedError):
+ with pytest.raises(jinja2.UndefinedError):
task.render_template("{{ foo }}", {})
def test_render_template_field_undefined_not_strict(self):
@@ -199,26 +199,24 @@ def test_render_template_field_undefined_not_strict(self):
with DAG("test-dag", start_date=DEFAULT_DATE, template_undefined=jinja2.Undefined):
task = DummyOperator(task_id="op1")
- self.assertEqual(task.render_template("{{ foo }}", {}), "")
+ assert task.render_template("{{ foo }}", {}) == ""
def test_nested_template_fields_declared_must_exist(self):
"""Test render_template when a nested template field is missing."""
with DAG("test-dag", start_date=DEFAULT_DATE):
task = DummyOperator(task_id="op1")
- with self.assertRaises(AttributeError) as e:
+ with pytest.raises(AttributeError) as ctx:
task.render_template(ClassWithCustomAttributes(template_fields=["missing_field"]), {})
- self.assertEqual(
- "'ClassWithCustomAttributes' object has no attribute 'missing_field'", str(e.exception)
- )
+ assert "'ClassWithCustomAttributes' object has no attribute 'missing_field'" == str(ctx.value)
def test_jinja_invalid_expression_is_just_propagated(self):
"""Test render_template propagates Jinja invalid expression errors."""
with DAG("test-dag", start_date=DEFAULT_DATE):
task = DummyOperator(task_id="op1")
- with self.assertRaises(jinja2.exceptions.TemplateSyntaxError):
+ with pytest.raises(jinja2.exceptions.TemplateSyntaxError):
task.render_template("{{ invalid expression }}", {})
@mock.patch("jinja2.Environment", autospec=True)
@@ -228,7 +226,7 @@ def test_jinja_env_creation(self, mock_jinja_env):
task = MockOperator(task_id="op1", arg1="{{ foo }}", arg2="{{ bar }}")
task.render_template_fields(context={"foo": "whatever", "bar": "whatever"})
- self.assertEqual(mock_jinja_env.call_count, 1)
+ assert mock_jinja_env.call_count == 1
def test_set_jinja_env_additional_option(self):
"""Test render_template given various input types."""
@@ -238,7 +236,7 @@ def test_set_jinja_env_additional_option(self):
task = DummyOperator(task_id="op1")
result = task.render_template("{{ foo }}\n\n", {"foo": "bar"})
- self.assertEqual(result, "bar\n\n")
+ assert result == "bar\n\n"
def test_override_jinja_env_option(self):
"""Test render_template given various input types."""
@@ -246,16 +244,16 @@ def test_override_jinja_env_option(self):
task = DummyOperator(task_id="op1")
result = task.render_template("{{ foo }}", {"foo": "bar"})
- self.assertEqual(result, "bar")
+ assert result == "bar"
def test_default_resources(self):
task = DummyOperator(task_id="default-resources")
- self.assertIsNone(task.resources)
+ assert task.resources is None
def test_custom_resources(self):
task = DummyOperator(task_id="custom-resources", resources={"cpus": 1, "ram": 1024})
- self.assertEqual(task.resources.cpus.qty, 1)
- self.assertEqual(task.resources.ram.qty, 1024)
+ assert task.resources.cpus.qty == 1
+ assert task.resources.ram.qty == 1024
def test_default_email_on_actions(self):
test_task = DummyOperator(task_id='test_default_email_on_actions')
@@ -279,28 +277,28 @@ def test_cross_downstream(self):
cross_downstream(from_tasks=start_tasks, to_tasks=end_tasks)
for start_task in start_tasks:
- self.assertCountEqual(start_task.get_direct_relatives(upstream=False), end_tasks)
+ assert set(start_task.get_direct_relatives(upstream=False)) == set(end_tasks)
def test_chain(self):
dag = DAG(dag_id='test_chain', start_date=datetime.now())
[op1, op2, op3, op4, op5, op6] = [DummyOperator(task_id=f't{i}', dag=dag) for i in range(1, 7)]
chain(op1, [op2, op3], [op4, op5], op6)
- self.assertCountEqual([op2, op3], op1.get_direct_relatives(upstream=False))
- self.assertEqual([op4], op2.get_direct_relatives(upstream=False))
- self.assertEqual([op5], op3.get_direct_relatives(upstream=False))
- self.assertCountEqual([op4, op5], op6.get_direct_relatives(upstream=True))
+ assert {op2, op3} == set(op1.get_direct_relatives(upstream=False))
+ assert [op4] == op2.get_direct_relatives(upstream=False)
+ assert [op5] == op3.get_direct_relatives(upstream=False)
+ assert {op4, op5} == set(op6.get_direct_relatives(upstream=True))
def test_chain_not_support_type(self):
dag = DAG(dag_id='test_chain', start_date=datetime.now())
[op1, op2] = [DummyOperator(task_id=f't{i}', dag=dag) for i in range(1, 3)]
- with self.assertRaises(TypeError):
+ with pytest.raises(TypeError):
chain([op1, op2], 1) # noqa
def test_chain_different_length_iterable(self):
dag = DAG(dag_id='test_chain', start_date=datetime.now())
[op1, op2, op3, op4, op5] = [DummyOperator(task_id=f't{i}', dag=dag) for i in range(1, 6)]
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
chain([op1, op2], [op3, op4, op5])
def test_lineage_composition(self):
@@ -319,43 +317,43 @@ def test_lineage_composition(self):
# note: operator precedence still applies
inlet > task1 | (task2 > outlet)
- self.assertEqual(task1.get_inlet_defs(), [inlet])
- self.assertEqual(task2.get_inlet_defs(), [task1.task_id])
- self.assertEqual(task2.get_outlet_defs(), [outlet])
+ assert task1.get_inlet_defs() == [inlet]
+ assert task2.get_inlet_defs() == [task1.task_id]
+ assert task2.get_outlet_defs() == [outlet]
fail = ClassWithCustomAttributes()
- with self.assertRaises(TypeError):
+ with pytest.raises(TypeError):
fail > task1
- with self.assertRaises(TypeError):
+ with pytest.raises(TypeError):
task1 > fail
- with self.assertRaises(TypeError):
+ with pytest.raises(TypeError):
fail | task1
- with self.assertRaises(TypeError):
+ with pytest.raises(TypeError):
task1 | fail
task3 = DummyOperator(task_id="op3", dag=dag)
extra = File(url="extra")
[inlet, extra] > task3
- self.assertEqual(task3.get_inlet_defs(), [inlet, extra])
+ assert task3.get_inlet_defs() == [inlet, extra]
task1.supports_lineage = False
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
task1 | task3
- self.assertEqual(task2.supports_lineage, False)
+ assert task2.supports_lineage is False
task2 | task3
- self.assertEqual(len(task3.get_inlet_defs()), 3)
+ assert len(task3.get_inlet_defs()) == 3
task4 = DummyOperator(task_id="op4", dag=dag)
task4 > [inlet, outlet, extra]
- self.assertEqual(task4.get_outlet_defs(), [inlet, outlet, extra])
+ assert task4.get_outlet_defs() == [inlet, outlet, extra]
def test_warnings_are_properly_propagated(self):
- with self.assertWarns(DeprecationWarning) as warns:
+ with pytest.warns(DeprecationWarning) as warnings:
DeprecatedOperator(task_id="test")
- assert len(warns.warnings) == 1
- warning = warns.warnings[0]
+ assert len(warnings) == 1
+ warning = warnings[0]
# Here we check that the trace points to the place
# where the deprecated class was used
assert warning.filename == __file__
diff --git a/tests/models/test_cleartasks.py b/tests/models/test_cleartasks.py
index c64c7bef4b020..f54bacce339f4 100644
--- a/tests/models/test_cleartasks.py
+++ b/tests/models/test_cleartasks.py
@@ -62,10 +62,10 @@ def test_clear_task_instances(self):
ti0.refresh_from_db()
ti1.refresh_from_db()
# Next try to run will be try 2
- self.assertEqual(ti0.try_number, 2)
- self.assertEqual(ti0.max_tries, 1)
- self.assertEqual(ti1.try_number, 2)
- self.assertEqual(ti1.max_tries, 3)
+ assert ti0.try_number == 2
+ assert ti0.max_tries == 1
+ assert ti1.try_number == 2
+ assert ti1.max_tries == 3
def test_clear_task_instances_without_task(self):
dag = DAG(
@@ -89,8 +89,8 @@ def test_clear_task_instances_without_task(self):
# Remove the task from dag.
dag.task_dict = {}
- self.assertFalse(dag.has_task(task0.task_id))
- self.assertFalse(dag.has_task(task1.task_id))
+ assert not dag.has_task(task0.task_id)
+ assert not dag.has_task(task1.task_id)
with create_session() as session:
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
@@ -100,10 +100,10 @@ def test_clear_task_instances_without_task(self):
ti0.refresh_from_db()
ti1.refresh_from_db()
# Next try to run will be try 2
- self.assertEqual(ti0.try_number, 2)
- self.assertEqual(ti0.max_tries, 1)
- self.assertEqual(ti1.try_number, 2)
- self.assertEqual(ti1.max_tries, 2)
+ assert ti0.try_number == 2
+ assert ti0.max_tries == 1
+ assert ti1.try_number == 2
+ assert ti1.max_tries == 2
def test_clear_task_instances_without_dag(self):
dag = DAG(
@@ -133,10 +133,10 @@ def test_clear_task_instances_without_dag(self):
ti0.refresh_from_db()
ti1.refresh_from_db()
# Next try to run will be try 2
- self.assertEqual(ti0.try_number, 2)
- self.assertEqual(ti0.max_tries, 1)
- self.assertEqual(ti1.try_number, 2)
- self.assertEqual(ti1.max_tries, 2)
+ assert ti0.try_number == 2
+ assert ti0.max_tries == 1
+ assert ti1.try_number == 2
+ assert ti1.max_tries == 2
def test_dag_clear(self):
dag = DAG(
@@ -152,33 +152,33 @@ def test_dag_clear(self):
)
# Next try to run will be try 1
- self.assertEqual(ti0.try_number, 1)
+ assert ti0.try_number == 1
ti0.run()
- self.assertEqual(ti0.try_number, 2)
+ assert ti0.try_number == 2
dag.clear()
ti0.refresh_from_db()
- self.assertEqual(ti0.try_number, 2)
- self.assertEqual(ti0.state, State.NONE)
- self.assertEqual(ti0.max_tries, 1)
+ assert ti0.try_number == 2
+ assert ti0.state == State.NONE
+ assert ti0.max_tries == 1
task1 = DummyOperator(task_id='test_dag_clear_task_1', owner='test', dag=dag, retries=2)
ti1 = TI(task=task1, execution_date=DEFAULT_DATE)
- self.assertEqual(ti1.max_tries, 2)
+ assert ti1.max_tries == 2
ti1.try_number = 1
# Next try will be 2
ti1.run()
- self.assertEqual(ti1.try_number, 3)
- self.assertEqual(ti1.max_tries, 2)
+ assert ti1.try_number == 3
+ assert ti1.max_tries == 2
dag.clear()
ti0.refresh_from_db()
ti1.refresh_from_db()
# after clear dag, ti2 should show attempt 3 of 5
- self.assertEqual(ti1.max_tries, 4)
- self.assertEqual(ti1.try_number, 3)
+ assert ti1.max_tries == 4
+ assert ti1.try_number == 3
# after clear dag, ti1 should show attempt 2 of 2
- self.assertEqual(ti0.try_number, 2)
- self.assertEqual(ti0.max_tries, 1)
+ assert ti0.try_number == 2
+ assert ti0.max_tries == 1
def test_dags_clear(self):
# setup
@@ -207,32 +207,32 @@ def test_dags_clear(self):
# test clear all dags
for i in range(num_of_dags):
tis[i].run()
- self.assertEqual(tis[i].state, State.SUCCESS)
- self.assertEqual(tis[i].try_number, 2)
- self.assertEqual(tis[i].max_tries, 0)
+ assert tis[i].state == State.SUCCESS
+ assert tis[i].try_number == 2
+ assert tis[i].max_tries == 0
DAG.clear_dags(dags)
for i in range(num_of_dags):
tis[i].refresh_from_db()
- self.assertEqual(tis[i].state, State.NONE)
- self.assertEqual(tis[i].try_number, 2)
- self.assertEqual(tis[i].max_tries, 1)
+ assert tis[i].state == State.NONE
+ assert tis[i].try_number == 2
+ assert tis[i].max_tries == 1
# test dry_run
for i in range(num_of_dags):
tis[i].run()
- self.assertEqual(tis[i].state, State.SUCCESS)
- self.assertEqual(tis[i].try_number, 3)
- self.assertEqual(tis[i].max_tries, 1)
+ assert tis[i].state == State.SUCCESS
+ assert tis[i].try_number == 3
+ assert tis[i].max_tries == 1
DAG.clear_dags(dags, dry_run=True)
for i in range(num_of_dags):
tis[i].refresh_from_db()
- self.assertEqual(tis[i].state, State.SUCCESS)
- self.assertEqual(tis[i].try_number, 3)
- self.assertEqual(tis[i].max_tries, 1)
+ assert tis[i].state == State.SUCCESS
+ assert tis[i].try_number == 3
+ assert tis[i].max_tries == 1
# test only_failed
from random import randint
@@ -247,13 +247,13 @@ def test_dags_clear(self):
for i in range(num_of_dags):
tis[i].refresh_from_db()
if i != failed_dag_idx:
- self.assertEqual(tis[i].state, State.SUCCESS)
- self.assertEqual(tis[i].try_number, 3)
- self.assertEqual(tis[i].max_tries, 1)
+ assert tis[i].state == State.SUCCESS
+ assert tis[i].try_number == 3
+ assert tis[i].max_tries == 1
else:
- self.assertEqual(tis[i].state, State.NONE)
- self.assertEqual(tis[i].try_number, 3)
- self.assertEqual(tis[i].max_tries, 2)
+ assert tis[i].state == State.NONE
+ assert tis[i].try_number == 3
+ assert tis[i].max_tries == 2
def test_operator_clear(self):
dag = DAG(
@@ -277,16 +277,16 @@ def test_operator_clear(self):
ti2.run()
# Dependency not met
- self.assertEqual(ti2.try_number, 1)
- self.assertEqual(ti2.max_tries, 1)
+ assert ti2.try_number == 1
+ assert ti2.max_tries == 1
op2.clear(upstream=True)
ti1.run()
ti2.run(ignore_ti_state=True)
- self.assertEqual(ti1.try_number, 2)
+ assert ti1.try_number == 2
# max_tries is 0 because there is no task instance in db for ti1
# so clear won't change the max_tries.
- self.assertEqual(ti1.max_tries, 0)
- self.assertEqual(ti2.try_number, 2)
+ assert ti1.max_tries == 0
+ assert ti2.try_number == 2
# try_number (0) + retries(1)
- self.assertEqual(ti2.max_tries, 1)
+ assert ti2.max_tries == 1
diff --git a/tests/models/test_connection.py b/tests/models/test_connection.py
index 2723c3fc0846b..a96b89a804086 100644
--- a/tests/models/test_connection.py
+++ b/tests/models/test_connection.py
@@ -21,6 +21,7 @@
from collections import namedtuple
from unittest import mock
+import pytest
import sqlalchemy
from cryptography.fernet import Fernet
from parameterized import parameterized
@@ -72,8 +73,8 @@ def test_connection_extra_no_encryption(self):
encryption.
"""
test_connection = Connection(extra='testextra')
- self.assertFalse(test_connection.is_extra_encrypted)
- self.assertEqual(test_connection.extra, 'testextra')
+ assert not test_connection.is_extra_encrypted
+ assert test_connection.extra == 'testextra'
@conf_vars({('core', 'fernet_key'): Fernet.generate_key().decode()})
def test_connection_extra_with_encryption(self):
@@ -81,8 +82,8 @@ def test_connection_extra_with_encryption(self):
Tests extras on a new connection with encryption.
"""
test_connection = Connection(extra='testextra')
- self.assertTrue(test_connection.is_extra_encrypted)
- self.assertEqual(test_connection.extra, 'testextra')
+ assert test_connection.is_extra_encrypted
+ assert test_connection.extra == 'testextra'
def test_connection_extra_with_encryption_rotate_fernet_key(self):
"""
@@ -93,20 +94,20 @@ def test_connection_extra_with_encryption_rotate_fernet_key(self):
with conf_vars({('core', 'fernet_key'): key1.decode()}):
test_connection = Connection(extra='testextra')
- self.assertTrue(test_connection.is_extra_encrypted)
- self.assertEqual(test_connection.extra, 'testextra')
- self.assertEqual(Fernet(key1).decrypt(test_connection._extra.encode()), b'testextra')
+ assert test_connection.is_extra_encrypted
+ assert test_connection.extra == 'testextra'
+ assert Fernet(key1).decrypt(test_connection._extra.encode()) == b'testextra'
# Test decrypt of old value with new key
with conf_vars({('core', 'fernet_key'): ','.join([key2.decode(), key1.decode()])}):
crypto._fernet = None
- self.assertEqual(test_connection.extra, 'testextra')
+ assert test_connection.extra == 'testextra'
# Test decrypt of new value with new key
test_connection.rotate_fernet_key()
- self.assertTrue(test_connection.is_extra_encrypted)
- self.assertEqual(test_connection.extra, 'testextra')
- self.assertEqual(Fernet(key2).decrypt(test_connection._extra.encode()), b'testextra')
+ assert test_connection.is_extra_encrypted
+ assert test_connection.extra == 'testextra'
+ assert Fernet(key2).decrypt(test_connection._extra.encode()) == b'testextra'
test_from_uri_params = [
UriTestCaseConfig(
@@ -297,11 +298,11 @@ def test_connection_from_uri(self, test_config: UriTestCaseConfig):
for conn_attr, expected_val in test_config.test_conn_attributes.items():
actual_val = getattr(connection, conn_attr)
if expected_val is None:
- self.assertIsNone(expected_val)
+ assert expected_val is None
if isinstance(expected_val, dict):
- self.assertDictEqual(expected_val, actual_val)
+ assert expected_val == actual_val
else:
- self.assertEqual(expected_val, actual_val)
+ assert expected_val == actual_val
# pylint: disable=undefined-variable
@parameterized.expand([(x,) for x in test_from_uri_params], UriTestCaseConfig.uri_test_name)
@@ -317,13 +318,13 @@ def test_connection_get_uri_from_uri(self, test_config: UriTestCaseConfig):
connection = Connection(uri=test_config.test_uri)
generated_uri = connection.get_uri()
new_conn = Connection(uri=generated_uri)
- self.assertEqual(connection.conn_type, new_conn.conn_type)
- self.assertEqual(connection.login, new_conn.login)
- self.assertEqual(connection.password, new_conn.password)
- self.assertEqual(connection.host, new_conn.host)
- self.assertEqual(connection.port, new_conn.port)
- self.assertEqual(connection.schema, new_conn.schema)
- self.assertDictEqual(connection.extra_dejson, new_conn.extra_dejson)
+ assert connection.conn_type == new_conn.conn_type
+ assert connection.login == new_conn.login
+ assert connection.password == new_conn.password
+ assert connection.host == new_conn.host
+ assert connection.port == new_conn.port
+ assert connection.schema == new_conn.schema
+ assert connection.extra_dejson == new_conn.extra_dejson
# pylint: disable=undefined-variable
@parameterized.expand([(x,) for x in test_from_uri_params], UriTestCaseConfig.uri_test_name)
@@ -350,11 +351,11 @@ def test_connection_get_uri_from_conn(self, test_config: UriTestCaseConfig):
for conn_attr, expected_val in test_config.test_conn_attributes.items():
actual_val = getattr(new_conn, conn_attr)
if expected_val is None:
- self.assertIsNone(expected_val)
+ assert expected_val is None
if isinstance(expected_val, dict):
- self.assertDictEqual(expected_val, actual_val)
+ assert expected_val == actual_val
else:
- self.assertEqual(expected_val, actual_val)
+ assert expected_val == actual_val
@parameterized.expand(
[
@@ -426,12 +427,12 @@ def test_connection_get_uri_from_conn(self, test_config: UriTestCaseConfig):
def test_connection_from_with_auth_info(self, uri, uri_parts):
connection = Connection(uri=uri)
- self.assertEqual(connection.conn_type, uri_parts.conn_type)
- self.assertEqual(connection.login, uri_parts.login)
- self.assertEqual(connection.password, uri_parts.password)
- self.assertEqual(connection.host, uri_parts.host)
- self.assertEqual(connection.port, uri_parts.port)
- self.assertEqual(connection.schema, uri_parts.schema)
+ assert connection.conn_type == uri_parts.conn_type
+ assert connection.login == uri_parts.login
+ assert connection.password == uri_parts.password
+ assert connection.host == uri_parts.host
+ assert connection.port == uri_parts.port
+ assert connection.schema == uri_parts.schema
@mock.patch.dict(
'os.environ',
@@ -441,11 +442,11 @@ def test_connection_from_with_auth_info(self, uri, uri_parts):
)
def test_using_env_var(self):
conn = SqliteHook.get_connection(conn_id='test_uri')
- self.assertEqual('ec2.compute.com', conn.host)
- self.assertEqual('the_database', conn.schema)
- self.assertEqual('username', conn.login)
- self.assertEqual('password', conn.password)
- self.assertEqual(5432, conn.port)
+ assert 'ec2.compute.com' == conn.host
+ assert 'the_database' == conn.schema
+ assert 'username' == conn.login
+ assert 'password' == conn.password
+ assert 5432 == conn.port
@mock.patch.dict(
'os.environ',
@@ -455,11 +456,11 @@ def test_using_env_var(self):
)
def test_using_unix_socket_env_var(self):
conn = SqliteHook.get_connection(conn_id='test_uri_no_creds')
- self.assertEqual('ec2.compute.com', conn.host)
- self.assertEqual('the_database', conn.schema)
- self.assertIsNone(conn.login)
- self.assertIsNone(conn.password)
- self.assertIsNone(conn.port)
+ assert 'ec2.compute.com' == conn.host
+ assert 'the_database' == conn.schema
+ assert conn.login is None
+ assert conn.password is None
+ assert conn.port is None
def test_param_setup(self):
conn = Connection(
@@ -470,15 +471,15 @@ def test_param_setup(self):
password='airflow',
schema='airflow',
)
- self.assertEqual('localhost', conn.host)
- self.assertEqual('airflow', conn.schema)
- self.assertEqual('airflow', conn.login)
- self.assertEqual('airflow', conn.password)
- self.assertIsNone(conn.port)
+ assert 'localhost' == conn.host
+ assert 'airflow' == conn.schema
+ assert 'airflow' == conn.login
+ assert 'airflow' == conn.password
+ assert conn.port is None
def test_env_var_priority(self):
conn = SqliteHook.get_connection(conn_id='airflow_db')
- self.assertNotEqual('ec2.compute.com', conn.host)
+ assert 'ec2.compute.com' != conn.host
with mock.patch.dict(
'os.environ',
@@ -487,11 +488,11 @@ def test_env_var_priority(self):
},
):
conn = SqliteHook.get_connection(conn_id='airflow_db')
- self.assertEqual('ec2.compute.com', conn.host)
- self.assertEqual('the_database', conn.schema)
- self.assertEqual('username', conn.login)
- self.assertEqual('password', conn.password)
- self.assertEqual(5432, conn.port)
+ assert 'ec2.compute.com' == conn.host
+ assert 'the_database' == conn.schema
+ assert 'username' == conn.login
+ assert 'password' == conn.password
+ assert 5432 == conn.port
@mock.patch.dict(
'os.environ',
@@ -503,10 +504,10 @@ def test_env_var_priority(self):
def test_dbapi_get_uri(self):
conn = BaseHook.get_connection(conn_id='test_uri')
hook = conn.get_hook()
- self.assertEqual('postgres://username:password@ec2.compute.com:5432/the_database', hook.get_uri())
+ assert 'postgres://username:password@ec2.compute.com:5432/the_database' == hook.get_uri()
conn2 = BaseHook.get_connection(conn_id='test_uri_no_creds')
hook2 = conn2.get_hook()
- self.assertEqual('postgres://ec2.compute.com/the_database', hook2.get_uri())
+ assert 'postgres://ec2.compute.com/the_database' == hook2.get_uri()
@mock.patch.dict(
'os.environ',
@@ -519,8 +520,8 @@ def test_dbapi_get_sqlalchemy_engine(self):
conn = BaseHook.get_connection(conn_id='test_uri')
hook = conn.get_hook()
engine = hook.get_sqlalchemy_engine()
- self.assertIsInstance(engine, sqlalchemy.engine.Engine)
- self.assertEqual('postgres://username:password@ec2.compute.com:5432/the_database', str(engine.url))
+ assert isinstance(engine, sqlalchemy.engine.Engine)
+ assert 'postgres://username:password@ec2.compute.com:5432/the_database' == str(engine.url)
@mock.patch.dict(
'os.environ',
@@ -539,9 +540,9 @@ def test_get_connections_env_var(self):
assert conns[0].port == 5432
def test_connection_mixed(self):
- with self.assertRaisesRegex(
+ with pytest.raises(
AirflowException,
- re.escape(
+ match=re.escape(
"You must create an object using the URI or individual values (conn_type, host, login, "
"password, schema, port or extra).You can't mix these two ways to create this object."
),
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 295e9eccba751..c929cd897daa9 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -102,8 +102,8 @@ def test_params_not_passed_is_empty_dict(self):
"""
dag = models.DAG('test-dag')
- self.assertEqual(dict, type(dag.params))
- self.assertEqual(0, len(dag.params))
+ assert isinstance(dag.params, dict)
+ assert 0 == len(dag.params)
def test_params_passed_and_params_in_default_args_no_override(self):
"""
@@ -119,13 +119,13 @@ def test_params_passed_and_params_in_default_args_no_override(self):
params_combined = params1.copy()
params_combined.update(params2)
- self.assertEqual(params_combined, dag.params)
+ assert params_combined == dag.params
def test_dag_invalid_default_view(self):
"""
Test invalid `default_view` of DAG initialization
"""
- with self.assertRaisesRegex(AirflowException, 'Invalid values of dag.default_view: only support'):
+ with pytest.raises(AirflowException, match='Invalid values of dag.default_view: only support'):
models.DAG(dag_id='test-invalid-default_view', default_view='airflow')
def test_dag_default_view_default_value(self):
@@ -133,13 +133,13 @@ def test_dag_default_view_default_value(self):
Test `default_view` default value of DAG initialization
"""
dag = models.DAG(dag_id='test-default_default_view')
- self.assertEqual(conf.get('webserver', 'dag_default_view').lower(), dag.default_view)
+ assert conf.get('webserver', 'dag_default_view').lower() == dag.default_view
def test_dag_invalid_orientation(self):
"""
Test invalid `orientation` of DAG initialization
"""
- with self.assertRaisesRegex(AirflowException, 'Invalid values of dag.orientation: only support'):
+ with pytest.raises(AirflowException, match='Invalid values of dag.orientation: only support'):
models.DAG(dag_id='test-invalid-orientation', orientation='airflow')
def test_dag_orientation_default_value(self):
@@ -147,7 +147,7 @@ def test_dag_orientation_default_value(self):
Test `orientation` default value of DAG initialization
"""
dag = models.DAG(dag_id='test-default_orientation')
- self.assertEqual(conf.get('webserver', 'dag_orientation'), dag.orientation)
+ assert conf.get('webserver', 'dag_orientation') == dag.orientation
def test_dag_as_context_manager(self):
"""
@@ -162,32 +162,32 @@ def test_dag_as_context_manager(self):
op1 = DummyOperator(task_id='op1')
op2 = DummyOperator(task_id='op2', dag=dag2)
- self.assertIs(op1.dag, dag)
- self.assertEqual(op1.owner, 'owner1')
- self.assertIs(op2.dag, dag2)
- self.assertEqual(op2.owner, 'owner2')
+ assert op1.dag is dag
+ assert op1.owner == 'owner1'
+ assert op2.dag is dag2
+ assert op2.owner == 'owner2'
with dag2:
op3 = DummyOperator(task_id='op3')
- self.assertIs(op3.dag, dag2)
- self.assertEqual(op3.owner, 'owner2')
+ assert op3.dag is dag2
+ assert op3.owner == 'owner2'
with dag:
with dag2:
op4 = DummyOperator(task_id='op4')
op5 = DummyOperator(task_id='op5')
- self.assertIs(op4.dag, dag2)
- self.assertIs(op5.dag, dag)
- self.assertEqual(op4.owner, 'owner2')
- self.assertEqual(op5.owner, 'owner1')
+ assert op4.dag is dag2
+ assert op5.dag is dag
+ assert op4.owner == 'owner2'
+ assert op5.owner == 'owner1'
with DAG('creating_dag_in_cm', start_date=DEFAULT_DATE) as dag:
DummyOperator(task_id='op6')
- self.assertEqual(dag.dag_id, 'creating_dag_in_cm')
- self.assertEqual(dag.tasks[0].task_id, 'op6')
+ assert dag.dag_id == 'creating_dag_in_cm'
+ assert dag.tasks[0].task_id == 'op6'
with dag:
with dag:
@@ -196,9 +196,9 @@ def test_dag_as_context_manager(self):
op9 = DummyOperator(task_id='op8')
op9.dag = dag2
- self.assertEqual(op7.dag, dag)
- self.assertEqual(op8.dag, dag)
- self.assertEqual(op9.dag, dag2)
+ assert op7.dag == dag
+ assert op8.dag == dag
+ assert op9.dag == dag2
def test_dag_topological_sort_include_subdag_tasks(self):
child_dag = DAG(
@@ -227,11 +227,11 @@ def test_dag_topological_sort_include_subdag_tasks(self):
topological_list = parent_dag.topological_sort(include_subdag_tasks=True)
- self.assertTrue(self._occur_before('a_parent', 'child_dag', topological_list))
- self.assertTrue(self._occur_before('child_dag', 'a_child', topological_list))
- self.assertTrue(self._occur_before('child_dag', 'b_child', topological_list))
- self.assertTrue(self._occur_before('a_child', 'b_parent', topological_list))
- self.assertTrue(self._occur_before('b_child', 'b_parent', topological_list))
+ assert self._occur_before('a_parent', 'child_dag', topological_list)
+ assert self._occur_before('child_dag', 'a_child', topological_list)
+ assert self._occur_before('child_dag', 'b_child', topological_list)
+ assert self._occur_before('a_child', 'b_parent', topological_list)
+ assert self._occur_before('b_child', 'b_parent', topological_list)
def test_dag_topological_sort1(self):
dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
@@ -251,13 +251,13 @@ def test_dag_topological_sort1(self):
logging.info(topological_list)
tasks = [op2, op3, op4]
- self.assertTrue(topological_list[0] in tasks)
+ assert topological_list[0] in tasks
tasks.remove(topological_list[0])
- self.assertTrue(topological_list[1] in tasks)
+ assert topological_list[1] in tasks
tasks.remove(topological_list[1])
- self.assertTrue(topological_list[2] in tasks)
+ assert topological_list[2] in tasks
tasks.remove(topological_list[2])
- self.assertTrue(topological_list[3] == op1)
+ assert topological_list[3] == op1
def test_dag_topological_sort2(self):
dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
@@ -281,25 +281,25 @@ def test_dag_topological_sort2(self):
logging.info(topological_list)
set1 = [op4, op5]
- self.assertTrue(topological_list[0] in set1)
+ assert topological_list[0] in set1
set1.remove(topological_list[0])
set2 = [op1, op2]
set2.extend(set1)
- self.assertTrue(topological_list[1] in set2)
+ assert topological_list[1] in set2
set2.remove(topological_list[1])
- self.assertTrue(topological_list[2] in set2)
+ assert topological_list[2] in set2
set2.remove(topological_list[2])
- self.assertTrue(topological_list[3] in set2)
+ assert topological_list[3] in set2
- self.assertTrue(topological_list[4] == op3)
+ assert topological_list[4] == op3
def test_dag_topological_sort_dag_without_tasks(self):
dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
- self.assertEqual((), dag.topological_sort())
+ assert () == dag.topological_sort()
def test_dag_naive_start_date_string(self):
DAG('DAG', default_args={'start_date': '2019-06-01'})
@@ -321,14 +321,14 @@ def test_dag_start_date_propagates_to_end_date(self):
dag = DAG(
'DAG', default_args={'start_date': '2019-06-05T00:00:00+05:00', 'end_date': '2019-06-05T00:00:00'}
)
- self.assertEqual(dag.default_args['start_date'], dag.default_args['end_date'])
- self.assertEqual(dag.default_args['start_date'].tzinfo, dag.default_args['end_date'].tzinfo)
+ assert dag.default_args['start_date'] == dag.default_args['end_date']
+ assert dag.default_args['start_date'].tzinfo == dag.default_args['end_date'].tzinfo
def test_dag_naive_default_args_start_date(self):
dag = DAG('DAG', default_args={'start_date': datetime.datetime(2018, 1, 1)})
- self.assertEqual(dag.timezone, settings.TIMEZONE)
+ assert dag.timezone == settings.TIMEZONE
dag = DAG('DAG', start_date=datetime.datetime(2018, 1, 1))
- self.assertEqual(dag.timezone, settings.TIMEZONE)
+ assert dag.timezone == settings.TIMEZONE
def test_dag_none_default_args_start_date(self):
"""
@@ -336,7 +336,7 @@ def test_dag_none_default_args_start_date(self):
works.
"""
dag = DAG('DAG', default_args={'start_date': None})
- self.assertEqual(dag.timezone, settings.TIMEZONE)
+ assert dag.timezone == settings.TIMEZONE
def test_dag_task_priority_weight_total(self):
width = 5
@@ -365,7 +365,7 @@ def test_dag_task_priority_weight_total(self):
correct_weight = ((depth - (task_depth + 1)) * width + 1) * weight
calculated_weight = task.priority_weight_total
- self.assertEqual(calculated_weight, correct_weight)
+ assert calculated_weight == correct_weight
def test_dag_task_priority_weight_total_using_upstream(self):
# Same test as above except use 'upstream' for weight calculation
@@ -399,7 +399,7 @@ def test_dag_task_priority_weight_total_using_upstream(self):
correct_weight = (task_depth * width + 1) * weight
calculated_weight = task.priority_weight_total
- self.assertEqual(calculated_weight, correct_weight)
+ assert calculated_weight == correct_weight
def test_dag_task_priority_weight_total_using_absolute(self):
# Same test as above except use 'absolute' for weight calculation
@@ -429,12 +429,12 @@ def test_dag_task_priority_weight_total_using_absolute(self):
# the sum of each stages after this task + itself
correct_weight = weight
calculated_weight = task.priority_weight_total
- self.assertEqual(calculated_weight, correct_weight)
+ assert calculated_weight == correct_weight
def test_dag_task_invalid_weight_rule(self):
# Test if we enter an invalid weight rule
with DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
DummyOperator(task_id='should_fail', weight_rule='no rule')
def test_get_num_task_instances(self):
@@ -459,29 +459,18 @@ def test_get_num_task_instances(self):
session.merge(ti4)
session.commit()
- self.assertEqual(0, DAG.get_num_task_instances(test_dag_id, ['fakename'], session=session))
- self.assertEqual(4, DAG.get_num_task_instances(test_dag_id, [test_task_id], session=session))
- self.assertEqual(
- 4, DAG.get_num_task_instances(test_dag_id, ['fakename', test_task_id], session=session)
+ assert 0 == DAG.get_num_task_instances(test_dag_id, ['fakename'], session=session)
+ assert 4 == DAG.get_num_task_instances(test_dag_id, [test_task_id], session=session)
+ assert 4 == DAG.get_num_task_instances(test_dag_id, ['fakename', test_task_id], session=session)
+ assert 1 == DAG.get_num_task_instances(test_dag_id, [test_task_id], states=[None], session=session)
+ assert 2 == DAG.get_num_task_instances(
+ test_dag_id, [test_task_id], states=[State.RUNNING], session=session
)
- self.assertEqual(
- 1, DAG.get_num_task_instances(test_dag_id, [test_task_id], states=[None], session=session)
+ assert 3 == DAG.get_num_task_instances(
+ test_dag_id, [test_task_id], states=[None, State.RUNNING], session=session
)
- self.assertEqual(
- 2,
- DAG.get_num_task_instances(test_dag_id, [test_task_id], states=[State.RUNNING], session=session),
- )
- self.assertEqual(
- 3,
- DAG.get_num_task_instances(
- test_dag_id, [test_task_id], states=[None, State.RUNNING], session=session
- ),
- )
- self.assertEqual(
- 4,
- DAG.get_num_task_instances(
- test_dag_id, [test_task_id], states=[None, State.QUEUED, State.RUNNING], session=session
- ),
+ assert 4 == DAG.get_num_task_instances(
+ test_dag_id, [test_task_id], states=[None, State.QUEUED, State.RUNNING], session=session
)
session.close()
@@ -492,8 +481,8 @@ def jinja_udf(name):
dag = models.DAG('test-dag', start_date=DEFAULT_DATE, user_defined_filters={"hello": jinja_udf})
jinja_env = dag.get_template_env()
- self.assertIn('hello', jinja_env.filters)
- self.assertEqual(jinja_env.filters['hello'], jinja_udf)
+ assert 'hello' in jinja_env.filters
+ assert jinja_env.filters['hello'] == jinja_udf # pylint: disable=comparison-with-callable
def test_resolve_template_files_value(self):
@@ -511,7 +500,7 @@ def test_resolve_template_files_value(self):
task.template_ext = ('.template',)
task.resolve_template_files()
- self.assertEqual(task.test_field, '{{ ds }}')
+ assert task.test_field == '{{ ds }}'
def test_resolve_template_files_list(self):
@@ -529,7 +518,7 @@ def test_resolve_template_files_list(self):
task.template_ext = ('.template',)
task.resolve_template_files()
- self.assertEqual(task.test_field, ['{{ ds }}', 'some_string'])
+ assert task.test_field == ['{{ ds }}', 'some_string']
def test_following_previous_schedule(self):
"""
@@ -537,9 +526,7 @@ def test_following_previous_schedule(self):
"""
local_tz = pendulum.timezone('Europe/Zurich')
start = local_tz.convert(datetime.datetime(2018, 10, 28, 2, 55), dst_rule=pendulum.PRE_TRANSITION)
- self.assertEqual(
- start.isoformat(), "2018-10-28T02:55:00+02:00", "Pre-condition: start date is in DST"
- )
+ assert start.isoformat() == "2018-10-28T02:55:00+02:00", "Pre-condition: start date is in DST"
utc = timezone.convert_to_utc(start)
@@ -547,19 +534,19 @@ def test_following_previous_schedule(self):
_next = dag.following_schedule(utc)
next_local = local_tz.convert(_next)
- self.assertEqual(_next.isoformat(), "2018-10-28T01:00:00+00:00")
- self.assertEqual(next_local.isoformat(), "2018-10-28T02:00:00+01:00")
+ assert _next.isoformat() == "2018-10-28T01:00:00+00:00"
+ assert next_local.isoformat() == "2018-10-28T02:00:00+01:00"
prev = dag.previous_schedule(utc)
prev_local = local_tz.convert(prev)
- self.assertEqual(prev_local.isoformat(), "2018-10-28T02:50:00+02:00")
+ assert prev_local.isoformat() == "2018-10-28T02:50:00+02:00"
prev = dag.previous_schedule(_next)
prev_local = local_tz.convert(prev)
- self.assertEqual(prev_local.isoformat(), "2018-10-28T02:55:00+02:00")
- self.assertEqual(prev, utc)
+ assert prev_local.isoformat() == "2018-10-28T02:55:00+02:00"
+ assert prev == utc
def test_following_previous_schedule_daily_dag_cest_to_cet(self):
"""
@@ -575,20 +562,20 @@ def test_following_previous_schedule_daily_dag_cest_to_cet(self):
prev = dag.previous_schedule(utc)
prev_local = local_tz.convert(prev)
- self.assertEqual(prev_local.isoformat(), "2018-10-26T03:00:00+02:00")
- self.assertEqual(prev.isoformat(), "2018-10-26T01:00:00+00:00")
+ assert prev_local.isoformat() == "2018-10-26T03:00:00+02:00"
+ assert prev.isoformat() == "2018-10-26T01:00:00+00:00"
_next = dag.following_schedule(utc)
next_local = local_tz.convert(_next)
- self.assertEqual(next_local.isoformat(), "2018-10-28T03:00:00+01:00")
- self.assertEqual(_next.isoformat(), "2018-10-28T02:00:00+00:00")
+ assert next_local.isoformat() == "2018-10-28T03:00:00+01:00"
+ assert _next.isoformat() == "2018-10-28T02:00:00+00:00"
prev = dag.previous_schedule(_next)
prev_local = local_tz.convert(prev)
- self.assertEqual(prev_local.isoformat(), "2018-10-27T03:00:00+02:00")
- self.assertEqual(prev.isoformat(), "2018-10-27T01:00:00+00:00")
+ assert prev_local.isoformat() == "2018-10-27T03:00:00+02:00"
+ assert prev.isoformat() == "2018-10-27T01:00:00+00:00"
def test_following_previous_schedule_daily_dag_cet_to_cest(self):
"""
@@ -604,20 +591,20 @@ def test_following_previous_schedule_daily_dag_cet_to_cest(self):
prev = dag.previous_schedule(utc)
prev_local = local_tz.convert(prev)
- self.assertEqual(prev_local.isoformat(), "2018-03-24T03:00:00+01:00")
- self.assertEqual(prev.isoformat(), "2018-03-24T02:00:00+00:00")
+ assert prev_local.isoformat() == "2018-03-24T03:00:00+01:00"
+ assert prev.isoformat() == "2018-03-24T02:00:00+00:00"
_next = dag.following_schedule(utc)
next_local = local_tz.convert(_next)
- self.assertEqual(next_local.isoformat(), "2018-03-25T03:00:00+02:00")
- self.assertEqual(_next.isoformat(), "2018-03-25T01:00:00+00:00")
+ assert next_local.isoformat() == "2018-03-25T03:00:00+02:00"
+ assert _next.isoformat() == "2018-03-25T01:00:00+00:00"
prev = dag.previous_schedule(_next)
prev_local = local_tz.convert(prev)
- self.assertEqual(prev_local.isoformat(), "2018-03-24T03:00:00+01:00")
- self.assertEqual(prev.isoformat(), "2018-03-24T02:00:00+00:00")
+ assert prev_local.isoformat() == "2018-03-24T03:00:00+01:00"
+ assert prev.isoformat() == "2018-03-24T02:00:00+00:00"
def test_following_schedule_relativedelta(self):
"""
@@ -629,20 +616,19 @@ def test_following_schedule_relativedelta(self):
dag.add_task(BaseOperator(task_id="faketastic", owner='Also fake', start_date=TEST_DATE))
_next = dag.following_schedule(TEST_DATE)
- self.assertEqual(_next.isoformat(), "2015-01-02T01:00:00+00:00")
+ assert _next.isoformat() == "2015-01-02T01:00:00+00:00"
_next = dag.following_schedule(_next)
- self.assertEqual(_next.isoformat(), "2015-01-02T02:00:00+00:00")
+ assert _next.isoformat() == "2015-01-02T02:00:00+00:00"
def test_dagtag_repr(self):
clear_db_dags()
dag = DAG('dag-test-dagtag', start_date=DEFAULT_DATE, tags=['tag-1', 'tag-2'])
dag.sync_to_db()
with create_session() as session:
- self.assertEqual(
- {'tag-1', 'tag-2'},
- {repr(t) for t in session.query(DagTag).filter(DagTag.dag_id == 'dag-test-dagtag').all()},
- )
+ assert {'tag-1', 'tag-2'} == {
+ repr(t) for t in session.query(DagTag).filter(DagTag.dag_id == 'dag-test-dagtag').all()
+ }
def test_bulk_write_to_db(self):
clear_db_dags()
@@ -651,19 +637,15 @@ def test_bulk_write_to_db(self):
with assert_queries_count(5):
DAG.bulk_write_to_db(dags)
with create_session() as session:
- self.assertEqual(
- {'dag-bulk-sync-0', 'dag-bulk-sync-1', 'dag-bulk-sync-2', 'dag-bulk-sync-3'},
- {row[0] for row in session.query(DagModel.dag_id).all()},
- )
- self.assertEqual(
- {
- ('dag-bulk-sync-0', 'test-dag'),
- ('dag-bulk-sync-1', 'test-dag'),
- ('dag-bulk-sync-2', 'test-dag'),
- ('dag-bulk-sync-3', 'test-dag'),
- },
- set(session.query(DagTag.dag_id, DagTag.name).all()),
- )
+ assert {'dag-bulk-sync-0', 'dag-bulk-sync-1', 'dag-bulk-sync-2', 'dag-bulk-sync-3'} == {
+ row[0] for row in session.query(DagModel.dag_id).all()
+ }
+ assert {
+ ('dag-bulk-sync-0', 'test-dag'),
+ ('dag-bulk-sync-1', 'test-dag'),
+ ('dag-bulk-sync-2', 'test-dag'),
+ ('dag-bulk-sync-3', 'test-dag'),
+ } == set(session.query(DagTag.dag_id, DagTag.name).all())
# Re-sync should do fewer queries
with assert_queries_count(3):
DAG.bulk_write_to_db(dags)
@@ -675,42 +657,34 @@ def test_bulk_write_to_db(self):
with assert_queries_count(4):
DAG.bulk_write_to_db(dags)
with create_session() as session:
- self.assertEqual(
- {'dag-bulk-sync-0', 'dag-bulk-sync-1', 'dag-bulk-sync-2', 'dag-bulk-sync-3'},
- {row[0] for row in session.query(DagModel.dag_id).all()},
- )
- self.assertEqual(
- {
- ('dag-bulk-sync-0', 'test-dag'),
- ('dag-bulk-sync-0', 'test-dag2'),
- ('dag-bulk-sync-1', 'test-dag'),
- ('dag-bulk-sync-1', 'test-dag2'),
- ('dag-bulk-sync-2', 'test-dag'),
- ('dag-bulk-sync-2', 'test-dag2'),
- ('dag-bulk-sync-3', 'test-dag'),
- ('dag-bulk-sync-3', 'test-dag2'),
- },
- set(session.query(DagTag.dag_id, DagTag.name).all()),
- )
+ assert {'dag-bulk-sync-0', 'dag-bulk-sync-1', 'dag-bulk-sync-2', 'dag-bulk-sync-3'} == {
+ row[0] for row in session.query(DagModel.dag_id).all()
+ }
+ assert {
+ ('dag-bulk-sync-0', 'test-dag'),
+ ('dag-bulk-sync-0', 'test-dag2'),
+ ('dag-bulk-sync-1', 'test-dag'),
+ ('dag-bulk-sync-1', 'test-dag2'),
+ ('dag-bulk-sync-2', 'test-dag'),
+ ('dag-bulk-sync-2', 'test-dag2'),
+ ('dag-bulk-sync-3', 'test-dag'),
+ ('dag-bulk-sync-3', 'test-dag2'),
+ } == set(session.query(DagTag.dag_id, DagTag.name).all())
# Removing tags
for dag in dags:
dag.tags.remove("test-dag")
with assert_queries_count(4):
DAG.bulk_write_to_db(dags)
with create_session() as session:
- self.assertEqual(
- {'dag-bulk-sync-0', 'dag-bulk-sync-1', 'dag-bulk-sync-2', 'dag-bulk-sync-3'},
- {row[0] for row in session.query(DagModel.dag_id).all()},
- )
- self.assertEqual(
- {
- ('dag-bulk-sync-0', 'test-dag2'),
- ('dag-bulk-sync-1', 'test-dag2'),
- ('dag-bulk-sync-2', 'test-dag2'),
- ('dag-bulk-sync-3', 'test-dag2'),
- },
- set(session.query(DagTag.dag_id, DagTag.name).all()),
- )
+ assert {'dag-bulk-sync-0', 'dag-bulk-sync-1', 'dag-bulk-sync-2', 'dag-bulk-sync-3'} == {
+ row[0] for row in session.query(DagModel.dag_id).all()
+ }
+ assert {
+ ('dag-bulk-sync-0', 'test-dag2'),
+ ('dag-bulk-sync-1', 'test-dag2'),
+ ('dag-bulk-sync-2', 'test-dag2'),
+ ('dag-bulk-sync-3', 'test-dag2'),
+ } == set(session.query(DagTag.dag_id, DagTag.name).all())
def test_bulk_write_to_db_max_active_runs(self):
"""
@@ -765,17 +739,17 @@ def test_sync_to_db(self):
dag.sync_to_db(session=session)
orm_dag = session.query(DagModel).filter(DagModel.dag_id == 'dag').one()
- self.assertEqual(set(orm_dag.owners.split(', ')), {'owner1', 'owner2'})
- self.assertTrue(orm_dag.is_active)
- self.assertIsNotNone(orm_dag.default_view)
- self.assertEqual(orm_dag.default_view, conf.get('webserver', 'dag_default_view').lower())
- self.assertEqual(orm_dag.safe_dag_id, 'dag')
+ assert set(orm_dag.owners.split(', ')) == {'owner1', 'owner2'}
+ assert orm_dag.is_active
+ assert orm_dag.default_view is not None
+ assert orm_dag.default_view == conf.get('webserver', 'dag_default_view').lower()
+ assert orm_dag.safe_dag_id == 'dag'
orm_subdag = session.query(DagModel).filter(DagModel.dag_id == 'dag.subtask').one()
- self.assertEqual(set(orm_subdag.owners.split(', ')), {'owner1', 'owner2'})
- self.assertTrue(orm_subdag.is_active)
- self.assertEqual(orm_subdag.safe_dag_id, 'dag__dot__subtask')
- self.assertEqual(orm_subdag.fileloc, orm_dag.fileloc)
+ assert set(orm_subdag.owners.split(', ')) == {'owner1', 'owner2'}
+ assert orm_subdag.is_active
+ assert orm_subdag.safe_dag_id == 'dag__dot__subtask'
+ assert orm_subdag.fileloc == orm_dag.fileloc
session.close()
def test_sync_to_db_default_view(self):
@@ -798,8 +772,8 @@ def test_sync_to_db_default_view(self):
dag.sync_to_db(session=session)
orm_dag = session.query(DagModel).filter(DagModel.dag_id == 'dag').one()
- self.assertIsNotNone(orm_dag.default_view)
- self.assertEqual(orm_dag.default_view, "graph")
+ assert orm_dag.default_view is not None
+ assert orm_dag.default_view == "graph"
session.close()
@provide_session
@@ -841,13 +815,10 @@ def test_is_paused_subdag(self, session):
.all()
)
- self.assertEqual(
- {
- (dag_id, False),
- (subdag_id, False),
- },
- set(unpaused_dags),
- )
+ assert {
+ (dag_id, False),
+ (subdag_id, False),
+ } == set(unpaused_dags)
DagModel.get_dagmodel(dag.dag_id).set_is_paused(is_paused=True, including_subdags=False)
@@ -859,13 +830,10 @@ def test_is_paused_subdag(self, session):
.all()
)
- self.assertEqual(
- {
- (dag_id, True),
- (subdag_id, False),
- },
- set(paused_dags),
- )
+ assert {
+ (dag_id, True),
+ (subdag_id, False),
+ } == set(paused_dags)
DagModel.get_dagmodel(dag.dag_id).set_is_paused(is_paused=True)
@@ -877,23 +845,20 @@ def test_is_paused_subdag(self, session):
.all()
)
- self.assertEqual(
- {
- (dag_id, True),
- (subdag_id, True),
- },
- set(paused_dags),
- )
+ assert {
+ (dag_id, True),
+ (subdag_id, True),
+ } == set(paused_dags)
def test_existing_dag_is_paused_upon_creation(self):
dag = DAG('dag_paused')
dag.sync_to_db()
- self.assertFalse(dag.get_is_paused())
+ assert not dag.get_is_paused()
dag = DAG('dag_paused', is_paused_upon_creation=True)
dag.sync_to_db()
# Since the dag existed before, it should not follow the pause flag upon creation
- self.assertFalse(dag.get_is_paused())
+ assert not dag.get_is_paused()
def test_new_dag_is_paused_upon_creation(self):
dag = DAG('new_nonexisting_dag', is_paused_upon_creation=True)
@@ -902,7 +867,7 @@ def test_new_dag_is_paused_upon_creation(self):
orm_dag = session.query(DagModel).filter(DagModel.dag_id == 'new_nonexisting_dag').one()
# Since the dag didn't exist before, it should follow the pause flag upon creation
- self.assertTrue(orm_dag.is_paused)
+ assert orm_dag.is_paused
session.close()
def test_existing_dag_default_view(self):
@@ -911,8 +876,8 @@ def test_existing_dag_default_view(self):
session.add(DagModel(dag_id='dag_default_view_old', default_view=None))
session.commit()
orm_dag = session.query(DagModel).filter(DagModel.dag_id == 'dag_default_view_old').one()
- self.assertIsNone(orm_dag.default_view)
- self.assertEqual(orm_dag.get_default_view(), conf.get('webserver', 'dag_default_view').lower())
+ assert orm_dag.default_view is None
+ assert orm_dag.get_default_view() == conf.get('webserver', 'dag_default_view').lower()
def test_dag_is_deactivated_upon_dagfile_deletion(self):
dag_id = 'old_existing_dag'
@@ -928,13 +893,13 @@ def test_dag_is_deactivated_upon_dagfile_deletion(self):
orm_dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).one()
- self.assertTrue(orm_dag.is_active)
- self.assertEqual(orm_dag.fileloc, dag_fileloc)
+ assert orm_dag.is_active
+ assert orm_dag.fileloc == dag_fileloc
DagModel.deactivate_deleted_dags(list_py_file_paths(settings.DAGS_FOLDER))
orm_dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).one()
- self.assertFalse(orm_dag.is_active)
+ assert not orm_dag.is_active
# pylint: disable=no-member
session.execute(DagModel.__table__.delete().where(DagModel.dag_id == dag_id))
@@ -945,10 +910,10 @@ def test_dag_naive_default_args_start_date_with_timezone(self):
default_args = {'start_date': datetime.datetime(2018, 1, 1, tzinfo=local_tz)}
dag = DAG('DAG', default_args=default_args)
- self.assertEqual(dag.timezone.name, local_tz.name)
+ assert dag.timezone.name == local_tz.name
dag = DAG('DAG', default_args=default_args)
- self.assertEqual(dag.timezone.name, local_tz.name)
+ assert dag.timezone.name == local_tz.name
def test_roots(self):
"""Verify if dag.roots returns the root tasks of a DAG."""
@@ -960,7 +925,7 @@ def test_roots(self):
op5 = DummyOperator(task_id="t5")
[op1, op2] >> op3 >> [op4, op5]
- self.assertCountEqual(dag.roots, [op1, op2])
+ assert set(dag.roots) == {op1, op2}
def test_leaves(self):
"""Verify if dag.leaves returns the leaf tasks of a DAG."""
@@ -972,7 +937,7 @@ def test_leaves(self):
op5 = DummyOperator(task_id="t5")
[op1, op2] >> op3 >> [op4, op5]
- self.assertCountEqual(dag.leaves, [op4, op5])
+ assert set(dag.leaves) == {op4, op5}
def test_tree_view(self):
"""Verify correctness of dag.tree_view()."""
@@ -987,29 +952,29 @@ def test_tree_view(self):
stdout = stdout.getvalue()
stdout_lines = stdout.split("\n")
- self.assertIn('t1', stdout_lines[0])
- self.assertIn('t2', stdout_lines[1])
- self.assertIn('t3', stdout_lines[2])
+ assert 't1' in stdout_lines[0]
+ assert 't2' in stdout_lines[1]
+ assert 't3' in stdout_lines[2]
def test_duplicate_task_ids_not_allowed_with_dag_context_manager(self):
"""Verify tasks with Duplicate task_id raises error"""
- with self.assertRaisesRegex(DuplicateTaskIdFound, "Task id 't1' has already been added to the DAG"):
+ with pytest.raises(DuplicateTaskIdFound, match="Task id 't1' has already been added to the DAG"):
with DAG("test_dag", start_date=DEFAULT_DATE) as dag:
op1 = DummyOperator(task_id="t1")
op2 = BashOperator(task_id="t1", bash_command="sleep 1")
op1 >> op2
- self.assertEqual(dag.task_dict, {op1.task_id: op1})
+ assert dag.task_dict == {op1.task_id: op1}
def test_duplicate_task_ids_not_allowed_without_dag_context_manager(self):
"""Verify tasks with Duplicate task_id raises error"""
- with self.assertRaisesRegex(DuplicateTaskIdFound, "Task id 't1' has already been added to the DAG"):
+ with pytest.raises(DuplicateTaskIdFound, match="Task id 't1' has already been added to the DAG"):
dag = DAG("test_dag", start_date=DEFAULT_DATE)
op1 = DummyOperator(task_id="t1", dag=dag)
op2 = DummyOperator(task_id="t1", dag=dag)
op1 >> op2
- self.assertEqual(dag.task_dict, {op1.task_id: op1})
+ assert dag.task_dict == {op1.task_id: op1}
def test_duplicate_task_ids_for_same_task_is_allowed(self):
"""Verify that same tasks with Duplicate task_id do not raise error"""
@@ -1019,9 +984,9 @@ def test_duplicate_task_ids_for_same_task_is_allowed(self):
op1 >> op3
op2 >> op3
- self.assertEqual(op1, op2)
- self.assertEqual(dag.task_dict, {op1.task_id: op1, op3.task_id: op3})
- self.assertEqual(dag.task_dict, {op2.task_id: op2, op3.task_id: op3})
+ assert op1 == op2
+ assert dag.task_dict == {op1.task_id: op1, op3.task_id: op3}
+ assert dag.task_dict == {op2.task_id: op2, op3.task_id: op3}
def test_sub_dag_updates_all_references_while_deepcopy(self):
with DAG("test_dag", start_date=DEFAULT_DATE) as dag:
@@ -1032,7 +997,7 @@ def test_sub_dag_updates_all_references_while_deepcopy(self):
op2 >> op3
sub_dag = dag.sub_dag('t2', include_upstream=True, include_downstream=False)
- self.assertEqual(id(sub_dag.task_dict['t1'].downstream_list[0].dag), id(sub_dag))
+ assert id(sub_dag.task_dict['t1'].downstream_list[0].dag) == id(sub_dag)
def test_schedule_dag_no_previous_runs(self):
"""
@@ -1047,17 +1012,15 @@ def test_schedule_dag_no_previous_runs(self):
execution_date=TEST_DATE,
state=State.RUNNING,
)
- self.assertIsNotNone(dag_run)
- self.assertEqual(dag.dag_id, dag_run.dag_id)
- self.assertIsNotNone(dag_run.run_id)
- self.assertNotEqual('', dag_run.run_id)
- self.assertEqual(
- TEST_DATE,
- dag_run.execution_date,
- msg=f'dag_run.execution_date did not match expectation: {dag_run.execution_date}',
- )
- self.assertEqual(State.RUNNING, dag_run.state)
- self.assertFalse(dag_run.external_trigger)
+ assert dag_run is not None
+ assert dag.dag_id == dag_run.dag_id
+ assert dag_run.run_id is not None
+ assert '' != dag_run.run_id
+ assert (
+ TEST_DATE == dag_run.execution_date
+ ), f'dag_run.execution_date did not match expectation: {dag_run.execution_date}'
+ assert State.RUNNING == dag_run.state
+ assert not dag_run.external_trigger
dag.clear()
self._clean_up(dag_id)
@@ -1123,7 +1086,7 @@ def test_schedule_dag_once(self):
dag_id = "test_schedule_dag_once"
dag = DAG(dag_id=dag_id)
dag.schedule_interval = '@once'
- self.assertEqual(dag.normalized_schedule_interval, None)
+ assert dag.normalized_schedule_interval is None
dag.add_task(BaseOperator(task_id="faketastic", owner='Also fake', start_date=TEST_DATE))
# Sync once to create the DagModel
@@ -1161,8 +1124,8 @@ def test_fractional_seconds(self):
run.refresh_from_db()
- self.assertEqual(start_date, run.execution_date, "dag run execution_date loses precision")
- self.assertEqual(start_date, run.start_date, "dag run start_date loses precision ")
+ assert start_date == run.execution_date, "dag run execution_date loses precision"
+ assert start_date == run.start_date, "dag run start_date loses precision "
self._clean_up(dag_id)
def test_pickling(self):
@@ -1170,7 +1133,7 @@ def test_pickling(self):
args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
dag = DAG(test_dag_id, default_args=args)
dag_pickle = dag.pickle()
- self.assertEqual(dag_pickle.pickle.dag_id, dag.dag_id)
+ assert dag_pickle.pickle.dag_id == dag.dag_id
def test_rich_comparison_ops(self):
test_dag_id = 'test_rich_comparison_ops'
@@ -1193,42 +1156,42 @@ class DAGsubclass(DAG):
dag_.last_loaded = dag.last_loaded
# test identity equality
- self.assertEqual(dag, dag)
+ assert dag == dag # pylint: disable=comparison-with-itself
# test dag (in)equality based on _comps
- self.assertEqual(dag_eq, dag)
- self.assertNotEqual(dag_diff_name, dag)
- self.assertNotEqual(dag_diff_load_time, dag)
+ assert dag_eq == dag
+ assert dag_diff_name != dag
+ assert dag_diff_load_time != dag
# test dag inequality based on type even if _comps happen to match
- self.assertNotEqual(dag_subclass, dag)
+ assert dag_subclass != dag
# a dag should equal an unpickled version of itself
dump = pickle.dumps(dag)
- self.assertEqual(pickle.loads(dump), dag)
+ assert pickle.loads(dump) == dag
# dags are ordered based on dag_id no matter what the type is
- self.assertLess(dag, dag_diff_name)
- self.assertGreater(dag, dag_diff_load_time)
- self.assertLess(dag, dag_subclass_diff_name)
+ assert dag < dag_diff_name
+ assert dag > dag_diff_load_time
+ assert dag < dag_subclass_diff_name
# greater than should have been created automatically by functools
- self.assertGreater(dag_diff_name, dag)
+ assert dag_diff_name > dag
# hashes are non-random and match equality
- self.assertEqual(hash(dag), hash(dag))
- self.assertEqual(hash(dag_eq), hash(dag))
- self.assertNotEqual(hash(dag_diff_name), hash(dag))
- self.assertNotEqual(hash(dag_subclass), hash(dag))
+ assert hash(dag) == hash(dag)
+ assert hash(dag_eq) == hash(dag)
+ assert hash(dag_diff_name) != hash(dag)
+ assert hash(dag_subclass) != hash(dag)
def test_get_paused_dag_ids(self):
dag_id = "test_get_paused_dag_ids"
dag = DAG(dag_id, is_paused_upon_creation=True)
dag.sync_to_db()
- self.assertIsNotNone(DagModel.get_dagmodel(dag_id))
+ assert DagModel.get_dagmodel(dag_id) is not None
paused_dag_ids = DagModel.get_paused_dag_ids([dag_id])
- self.assertEqual(paused_dag_ids, {dag_id})
+ assert paused_dag_ids == {dag_id}
with create_session() as session:
session.query(DagModel).filter(DagModel.dag_id == dag_id).delete(synchronize_session=False)
@@ -1248,8 +1211,8 @@ def test_get_paused_dag_ids(self):
def test_normalized_schedule_interval(self, schedule_interval, expected_n_schedule_interval):
dag = DAG("test_schedule_interval", schedule_interval=schedule_interval)
- self.assertEqual(dag.normalized_schedule_interval, expected_n_schedule_interval)
- self.assertEqual(dag.schedule_interval, schedule_interval)
+ assert dag.normalized_schedule_interval == expected_n_schedule_interval
+ assert dag.schedule_interval == schedule_interval
def test_set_dag_runs_state(self):
clear_db_runs()
@@ -1331,9 +1294,9 @@ def test_clear_set_dagrun_state(self, dag_run_state):
.all()
)
- self.assertEqual(len(dagruns), 1)
+ assert len(dagruns) == 1
dagrun = dagruns[0] # type: DagRun
- self.assertEqual(dagrun.state, dag_run_state)
+ assert dagrun.state == dag_run_state
@parameterized.expand(
[(state, State.NONE) for state in State.task_states if state != State.RUNNING]
@@ -1376,9 +1339,9 @@ def test_clear_dag(self, ti_state_begin, ti_state_end: Optional[str]):
.all()
)
- self.assertEqual(len(task_instances), 1)
+ assert len(task_instances) == 1
task_instance = task_instances[0] # type: TI
- self.assertEqual(task_instance.state, ti_state_end)
+ assert task_instance.state == ti_state_end
self._clean_up(dag_id)
def test_next_dagrun_after_date_once(self):
@@ -1615,11 +1578,11 @@ def test_replace_outdated_access_control_actions(self):
with pytest.warns(DeprecationWarning):
dag = DAG(dag_id='dag_with_outdated_perms', access_control=outdated_permissions)
- self.assertEqual(dag.access_control, updated_permissions)
+ assert dag.access_control == updated_permissions
with pytest.warns(DeprecationWarning):
dag.access_control = outdated_permissions
- self.assertEqual(dag.access_control, updated_permissions)
+ assert dag.access_control == updated_permissions
class TestDagModel:
diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py
index 73b0bf0e41334..6c6b1cb9dddcb 100644
--- a/tests/models/test_dagbag.py
+++ b/tests/models/test_dagbag.py
@@ -70,10 +70,10 @@ def test_get_existing_dag(self):
for dag_id in some_expected_dag_ids:
dag = dagbag.get_dag(dag_id)
- self.assertIsNotNone(dag)
- self.assertEqual(dag_id, dag.dag_id)
+ assert dag is not None
+ assert dag_id == dag.dag_id
- self.assertGreaterEqual(dagbag.size(), 7)
+ assert dagbag.size() >= 7
def test_get_non_existing_dag(self):
"""
@@ -82,7 +82,7 @@ def test_get_non_existing_dag(self):
dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
non_existing_dag_id = "non_existing_dag_id"
- self.assertIsNone(dagbag.get_dag(non_existing_dag_id))
+ assert dagbag.get_dag(non_existing_dag_id) is None
def test_dont_load_example(self):
"""
@@ -90,7 +90,7 @@ def test_dont_load_example(self):
"""
dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
- self.assertEqual(dagbag.size(), 0)
+ assert dagbag.size() == 0
def test_safe_mode_heuristic_match(self):
"""With safe mode enabled, a file matching the discovery heuristics
@@ -104,8 +104,8 @@ def test_safe_mode_heuristic_match(self):
with conf_vars({('core', 'dags_folder'): self.empty_dir}):
dagbag = models.DagBag(include_examples=False, safe_mode=True)
- self.assertEqual(len(dagbag.dagbag_stats), 1)
- self.assertEqual(dagbag.dagbag_stats[0].file, "/{}".format(os.path.basename(f.name)))
+ assert len(dagbag.dagbag_stats) == 1
+ assert dagbag.dagbag_stats[0].file == "/{}".format(os.path.basename(f.name))
def test_safe_mode_heuristic_mismatch(self):
"""With safe mode enabled, a file not matching the discovery heuristics
@@ -114,15 +114,15 @@ def test_safe_mode_heuristic_mismatch(self):
with NamedTemporaryFile(dir=self.empty_dir, suffix=".py"):
with conf_vars({('core', 'dags_folder'): self.empty_dir}):
dagbag = models.DagBag(include_examples=False, safe_mode=True)
- self.assertEqual(len(dagbag.dagbag_stats), 0)
+ assert len(dagbag.dagbag_stats) == 0
def test_safe_mode_disabled(self):
"""With safe mode disabled, an empty python file should be discovered."""
with NamedTemporaryFile(dir=self.empty_dir, suffix=".py") as f:
with conf_vars({('core', 'dags_folder'): self.empty_dir}):
dagbag = models.DagBag(include_examples=False, safe_mode=False)
- self.assertEqual(len(dagbag.dagbag_stats), 1)
- self.assertEqual(dagbag.dagbag_stats[0].file, "/{}".format(os.path.basename(f.name)))
+ assert len(dagbag.dagbag_stats) == 1
+ assert dagbag.dagbag_stats[0].file == "/{}".format(os.path.basename(f.name))
def test_process_file_that_contains_multi_bytes_char(self):
"""
@@ -133,7 +133,7 @@ def test_process_file_that_contains_multi_bytes_char(self):
f.flush()
dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
- self.assertEqual([], dagbag.process_file(f.name))
+ assert [] == dagbag.process_file(f.name)
def test_zip_skip_log(self):
"""
@@ -144,11 +144,10 @@ def test_zip_skip_log(self):
test_zip_path = os.path.join(TEST_DAGS_FOLDER, "test_zip.zip")
dagbag = models.DagBag(dag_folder=test_zip_path, include_examples=False)
- self.assertTrue(dagbag.has_logged)
- self.assertIn(
+ assert dagbag.has_logged
+ assert (
f'INFO:airflow.models.dagbag.DagBag:File {test_zip_path}:file_no_airflow_dag.py '
- 'assumed to contain no DAGs. Skipping.',
- cm.output,
+ 'assumed to contain no DAGs. Skipping.' in cm.output
)
def test_zip(self):
@@ -157,7 +156,7 @@ def test_zip(self):
"""
dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
dagbag.process_file(os.path.join(TEST_DAGS_FOLDER, "test_zip.zip"))
- self.assertTrue(dagbag.get_dag("test_zip_dag"))
+ assert dagbag.get_dag("test_zip_dag")
def test_process_file_cron_validity_check(self):
"""
@@ -167,11 +166,11 @@ def test_process_file_cron_validity_check(self):
invalid_dag_files = ["test_invalid_cron.py", "test_zip_invalid_cron.zip"]
dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
- self.assertEqual(len(dagbag.import_errors), 0)
+ assert len(dagbag.import_errors) == 0
for file in invalid_dag_files:
dagbag.process_file(os.path.join(TEST_DAGS_FOLDER, file))
- self.assertEqual(len(dagbag.import_errors), len(invalid_dag_files))
- self.assertEqual(len(dagbag.dags), 0)
+ assert len(dagbag.import_errors) == len(invalid_dag_files)
+ assert len(dagbag.dags) == 0
@patch.object(DagModel, 'get_current')
def test_get_dag_without_refresh(self, mock_dagmodel):
@@ -197,9 +196,9 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True):
dagbag.process_file_calls
# Should not call process_file again, since it's already loaded during init.
- self.assertEqual(1, dagbag.process_file_calls)
- self.assertIsNotNone(dagbag.get_dag(dag_id))
- self.assertEqual(1, dagbag.process_file_calls)
+ assert 1 == dagbag.process_file_calls
+ assert dagbag.get_dag(dag_id) is not None
+ assert 1 == dagbag.process_file_calls
def test_get_dag_fileloc(self):
"""
@@ -218,7 +217,7 @@ def test_get_dag_fileloc(self):
for dag_id, path in expected.items():
dag = dagbag.get_dag(dag_id)
- self.assertTrue(dag.fileloc.endswith(path))
+ assert dag.fileloc.endswith(path)
@patch.object(DagModel, "get_current")
def test_refresh_py_dag(self, mock_dagmodel):
@@ -244,11 +243,11 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True):
dagbag = _TestDagBag(dag_folder=self.empty_dir, include_examples=True)
- self.assertEqual(1, dagbag.process_file_calls)
+ assert 1 == dagbag.process_file_calls
dag = dagbag.get_dag(dag_id)
- self.assertIsNotNone(dag)
- self.assertEqual(dag_id, dag.dag_id)
- self.assertEqual(2, dagbag.process_file_calls)
+ assert dag is not None
+ assert dag_id == dag.dag_id
+ assert 2 == dagbag.process_file_calls
@patch.object(DagModel, "get_current")
def test_refresh_packaged_dag(self, mock_dagmodel):
@@ -272,11 +271,11 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True):
dagbag = _TestDagBag(dag_folder=os.path.realpath(TEST_DAGS_FOLDER), include_examples=False)
- self.assertEqual(1, dagbag.process_file_calls)
+ assert 1 == dagbag.process_file_calls
dag = dagbag.get_dag(dag_id)
- self.assertIsNotNone(dag)
- self.assertEqual(dag_id, dag.dag_id)
- self.assertEqual(2, dagbag.process_file_calls)
+ assert dag is not None
+ assert dag_id == dag.dag_id
+ assert 2 == dagbag.process_file_calls
def process_dag(self, create_dag):
"""
@@ -300,17 +299,19 @@ def validate_dags(self, expected_parent_dag, actual_found_dags, actual_dagbag, s
for dag_id in expected_dag_ids:
actual_dagbag.log.info('validating %s' % dag_id)
- self.assertEqual(
- dag_id in actual_found_dag_ids,
- should_be_found,
- 'dag "%s" should %shave been found after processing dag "%s"'
- % (dag_id, '' if should_be_found else 'not ', expected_parent_dag.dag_id),
+ assert (
+ dag_id in actual_found_dag_ids
+ ) == should_be_found, 'dag "{}" should {}have been found after processing dag "{}"'.format(
+ dag_id,
+ '' if should_be_found else 'not ',
+ expected_parent_dag.dag_id,
)
- self.assertEqual(
- dag_id in actual_dagbag.dags,
- should_be_found,
- 'dag "%s" should %sbe in dagbag.dags after processing dag "%s"'
- % (dag_id, '' if should_be_found else 'not ', expected_parent_dag.dag_id),
+ assert (
+ dag_id in actual_dagbag.dags
+ ) == should_be_found, 'dag "{}" should {}be in dagbag.dags after processing dag "{}"'.format(
+ dag_id,
+ '' if should_be_found else 'not ',
+ expected_parent_dag.dag_id,
)
def test_load_subdags(self):
@@ -356,7 +357,7 @@ def subdag_1():
test_dag = standard_subdag()
# sanity check to make sure DAG.subdag is still functioning properly
- self.assertEqual(len(test_dag.subdags), 2)
+ assert len(test_dag.subdags) == 2
# Perform processing dag
dagbag, found_dags, _ = self.process_dag(standard_subdag)
@@ -440,7 +441,7 @@ def subdag_1():
test_dag = nested_subdags()
# sanity check to make sure DAG.subdag is still functioning properly
- self.assertEqual(len(test_dag.subdags), 6)
+ assert len(test_dag.subdags) == 6
# Perform processing dag
dagbag, found_dags, _ = self.process_dag(nested_subdags)
@@ -475,7 +476,7 @@ def basic_cycle():
test_dag = basic_cycle()
# sanity check to make sure DAG.subdag is still functioning properly
- self.assertEqual(len(test_dag.subdags), 0)
+ assert len(test_dag.subdags) == 0
# Perform processing dag
dagbag, found_dags, file_path = self.process_dag(basic_cycle)
@@ -483,7 +484,7 @@ def basic_cycle():
# #Validate correctness
# None of the dags should be found
self.validate_dags(test_dag, found_dags, dagbag, should_be_found=False)
- self.assertIn(file_path, dagbag.import_errors)
+ assert file_path in dagbag.import_errors
# Define Dag to load
def nested_subdag_cycle():
@@ -562,7 +563,7 @@ def subdag_1():
test_dag = nested_subdag_cycle()
# sanity check to make sure DAG.subdag is still functioning properly
- self.assertEqual(len(test_dag.subdags), 6)
+ assert len(test_dag.subdags) == 6
# Perform processing dag
dagbag, found_dags, file_path = self.process_dag(nested_subdag_cycle)
@@ -570,7 +571,7 @@ def subdag_1():
# Validate correctness
# None of the dags should be found
self.validate_dags(test_dag, found_dags, dagbag, should_be_found=False)
- self.assertIn(file_path, dagbag.import_errors)
+ assert file_path in dagbag.import_errors
def test_process_file_with_none(self):
"""
@@ -578,7 +579,7 @@ def test_process_file_with_none(self):
"""
dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
- self.assertEqual([], dagbag.process_file(None))
+ assert [] == dagbag.process_file(None)
def test_deactivate_unknown_dags(self):
"""
@@ -596,8 +597,8 @@ def test_deactivate_unknown_dags(self):
models.DAG.deactivate_unknown_dags(expected_active_dags)
after_model = DagModel.get_dagmodel(dag_id)
- self.assertTrue(model_before.is_active)
- self.assertFalse(after_model.is_active)
+ assert model_before.is_active
+ assert not after_model.is_active
# clean up
with create_session() as session:
@@ -610,7 +611,7 @@ def test_serialized_dags_are_written_to_db_on_sync(self):
"""
with create_session() as session:
serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar()
- self.assertEqual(serialized_dags_count, 0)
+ assert serialized_dags_count == 0
dagbag = DagBag(
dag_folder=os.path.join(TEST_DAGS_FOLDER, "test_example_bash_operator.py"),
@@ -618,10 +619,10 @@ def test_serialized_dags_are_written_to_db_on_sync(self):
)
dagbag.sync_to_db()
- self.assertFalse(dagbag.read_dags_from_db)
+ assert not dagbag.read_dags_from_db
new_serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar()
- self.assertEqual(new_serialized_dags_count, 1)
+ assert new_serialized_dags_count == 1
@patch("airflow.models.serialized_dag.SerializedDagModel.write_dag")
def test_serialized_dag_errors_are_import_errors(self, mock_serialize):
@@ -701,14 +702,14 @@ def test_get_dag_with_dag_serialization(self):
dag_bag = DagBag(read_dags_from_db=True)
ser_dag_1 = dag_bag.get_dag("example_bash_operator")
ser_dag_1_update_time = dag_bag.dags_last_fetched["example_bash_operator"]
- self.assertEqual(example_bash_op_dag.tags, ser_dag_1.tags)
- self.assertEqual(ser_dag_1_update_time, tz.datetime(2020, 1, 5, 0, 0, 0))
+ assert example_bash_op_dag.tags == ser_dag_1.tags
+ assert ser_dag_1_update_time == tz.datetime(2020, 1, 5, 0, 0, 0)
# Check that if min_serialized_dag_fetch_interval has not passed we do not fetch the DAG
# from DB
with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 4)):
with assert_queries_count(0):
- self.assertEqual(dag_bag.get_dag("example_bash_operator").tags, ["example", "example2"])
+ assert dag_bag.get_dag("example_bash_operator").tags == ["example", "example2"]
# Make a change in the DAG and write Serialized DAG to the DB
with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 6)):
@@ -722,8 +723,8 @@ def test_get_dag_with_dag_serialization(self):
updated_ser_dag_1 = dag_bag.get_dag("example_bash_operator")
updated_ser_dag_1_update_time = dag_bag.dags_last_fetched["example_bash_operator"]
- self.assertCountEqual(updated_ser_dag_1.tags, ["example", "example2", "new_tag"])
- self.assertGreater(updated_ser_dag_1_update_time, ser_dag_1_update_time)
+ assert set(updated_ser_dag_1.tags) == {"example", "example2", "new_tag"}
+ assert updated_ser_dag_1_update_time > ser_dag_1_update_time
def test_collect_dags_from_db(self):
"""DAGs are collected from Database"""
@@ -735,15 +736,15 @@ def test_collect_dags_from_db(self):
SerializedDagModel.write_dag(dag)
new_dagbag = DagBag(read_dags_from_db=True)
- self.assertEqual(len(new_dagbag.dags), 0)
+ assert len(new_dagbag.dags) == 0
new_dagbag.collect_dags_from_db()
new_dags = new_dagbag.dags
- self.assertEqual(len(example_dags), len(new_dags))
+ assert len(example_dags) == len(new_dags)
for dag_id, dag in example_dags.items():
serialized_dag = new_dags[dag_id]
- self.assertEqual(serialized_dag.dag_id, dag.dag_id)
- self.assertEqual(set(serialized_dag.task_dict), set(dag.task_dict))
+ assert serialized_dag.dag_id == dag.dag_id
+ assert set(serialized_dag.task_dict) == set(dag.task_dict)
@patch("airflow.settings.task_policy", cluster_policies.cluster_policy)
def test_task_cluster_policy_violation(self):
@@ -754,7 +755,7 @@ def test_task_cluster_policy_violation(self):
dag_file = os.path.join(TEST_DAGS_FOLDER, "test_missing_owner.py")
dagbag = DagBag(dag_folder=dag_file, include_smart_sensor=False, include_examples=False)
- self.assertEqual(set(), set(dagbag.dag_ids))
+ assert set() == set(dagbag.dag_ids)
expected_import_errors = {
dag_file: (
f"""DAG policy violation (DAG ID: test_missing_owner, Path: {dag_file}):\n"""
@@ -762,7 +763,7 @@ def test_task_cluster_policy_violation(self):
""" * Task must have non-None non-default owner. Current value: airflow"""
)
}
- self.assertEqual(expected_import_errors, dagbag.import_errors)
+ assert expected_import_errors == dagbag.import_errors
@patch("airflow.settings.task_policy", cluster_policies.cluster_policy)
def test_task_cluster_policy_obeyed(self):
@@ -773,9 +774,9 @@ def test_task_cluster_policy_obeyed(self):
dag_file = os.path.join(TEST_DAGS_FOLDER, "test_with_non_default_owner.py")
dagbag = DagBag(dag_folder=dag_file, include_examples=False, include_smart_sensor=False)
- self.assertEqual({"test_with_non_default_owner"}, set(dagbag.dag_ids))
+ assert {"test_with_non_default_owner"} == set(dagbag.dag_ids)
- self.assertEqual({}, dagbag.import_errors)
+ assert {} == dagbag.import_errors
@patch("airflow.settings.dag_policy", cluster_policies.dag_policy)
def test_dag_cluster_policy_obeyed(self):
diff --git a/tests/models/test_dagcode.py b/tests/models/test_dagcode.py
index ee2eb2aaa61d4..e1dcfe6dfc6dd 100644
--- a/tests/models/test_dagcode.py
+++ b/tests/models/test_dagcode.py
@@ -19,6 +19,8 @@
from datetime import timedelta
from unittest.mock import patch
+import pytest
+
from airflow import AirflowException, example_dags as example_dags_module
from airflow.models import DagBag
from airflow.models.dagcode import DagCode
@@ -98,7 +100,7 @@ def test_detecting_duplicate_key(self, mock_hash):
"""Dag code detects duplicate key."""
mock_hash.return_value = 0
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self._write_two_example_dags()
def _compare_example_dags(self, example_dags):
@@ -106,7 +108,7 @@ def _compare_example_dags(self, example_dags):
for dag in example_dags.values():
if dag.is_subdag:
dag.fileloc = dag.parent_dag.fileloc
- self.assertTrue(DagCode.has_dag(dag.fileloc))
+ assert DagCode.has_dag(dag.fileloc)
dag_fileloc_hash = DagCode.dag_fileloc_hash(dag.fileloc)
result = (
session.query(DagCode.fileloc, DagCode.fileloc_hash, DagCode.source_code)
@@ -115,10 +117,10 @@ def _compare_example_dags(self, example_dags):
.one()
)
- self.assertEqual(result.fileloc, dag.fileloc)
+ assert result.fileloc == dag.fileloc
with open_maybe_zipped(dag.fileloc, 'r') as source:
source_code = source.read()
- self.assertEqual(result.source_code, source_code)
+ assert result.source_code == source_code
@conf_vars({('core', 'store_dag_code'): 'True'})
@patch("airflow.models.dag.settings.STORE_DAG_CODE", True)
@@ -137,7 +139,7 @@ def test_code_can_be_read_when_no_access_to_file(self):
dag_code = DagCode.get_code_by_fileloc(example_dag.fileloc)
for test_string in ['example_bash_operator', 'also_run_this', 'run_this_last']:
- self.assertIn(test_string, dag_code)
+ assert test_string in dag_code
@conf_vars({('core', 'store_dag_code'): 'True'})
@patch("airflow.models.dag.settings.STORE_DAG_CODE", True)
@@ -149,8 +151,8 @@ def test_db_code_updated_on_dag_file_change(self):
with create_session() as session:
result = session.query(DagCode).filter(DagCode.fileloc == example_dag.fileloc).one()
- self.assertEqual(result.fileloc, example_dag.fileloc)
- self.assertIsNotNone(result.source_code)
+ assert result.fileloc == example_dag.fileloc
+ assert result.source_code is not None
with patch('airflow.models.dagcode.os.path.getmtime') as mock_mtime:
mock_mtime.return_value = (result.last_updated + timedelta(seconds=1)).timestamp()
@@ -162,6 +164,6 @@ def test_db_code_updated_on_dag_file_change(self):
with create_session() as session:
new_result = session.query(DagCode).filter(DagCode.fileloc == example_dag.fileloc).one()
- self.assertEqual(new_result.fileloc, example_dag.fileloc)
- self.assertEqual(new_result.source_code, "# dummy code")
- self.assertGreater(new_result.last_updated, result.last_updated)
+ assert new_result.fileloc == example_dag.fileloc
+ assert new_result.source_code == "# dummy code"
+ assert new_result.last_updated > result.last_updated
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index f27c50aa61c8c..ad2bbcbfd6be8 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -102,7 +102,7 @@ def test_clear_task_instances_for_backfill_dagrun(self):
session.commit()
ti0.refresh_from_db()
dr0 = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.execution_date == now).first()
- self.assertEqual(dr0.state, State.RUNNING)
+ assert dr0.state == State.RUNNING
def test_dagrun_find(self):
session = settings.Session()
@@ -132,10 +132,10 @@ def test_dagrun_find(self):
session.commit()
- self.assertEqual(1, len(models.DagRun.find(dag_id=dag_id1, external_trigger=True)))
- self.assertEqual(0, len(models.DagRun.find(dag_id=dag_id1, external_trigger=False)))
- self.assertEqual(0, len(models.DagRun.find(dag_id=dag_id2, external_trigger=True)))
- self.assertEqual(1, len(models.DagRun.find(dag_id=dag_id2, external_trigger=False)))
+ assert 1 == len(models.DagRun.find(dag_id=dag_id1, external_trigger=True))
+ assert 0 == len(models.DagRun.find(dag_id=dag_id1, external_trigger=False))
+ assert 0 == len(models.DagRun.find(dag_id=dag_id2, external_trigger=True))
+ assert 1 == len(models.DagRun.find(dag_id=dag_id2, external_trigger=False))
def test_dagrun_success_when_all_skipped(self):
"""
@@ -158,7 +158,7 @@ def test_dagrun_success_when_all_skipped(self):
dag_run = self.create_dag_run(dag=dag, state=State.RUNNING, task_states=initial_task_states)
dag_run.update_state()
- self.assertEqual(State.SUCCESS, dag_run.state)
+ assert State.SUCCESS == dag_run.state
def test_dagrun_success_conditions(self):
session = settings.Session()
@@ -193,14 +193,14 @@ def test_dagrun_success_conditions(self):
# root is successful, but unfinished tasks
dr.update_state()
- self.assertEqual(State.RUNNING, dr.state)
+ assert State.RUNNING == dr.state
# one has failed, but root is successful
ti_op2.set_state(state=State.FAILED, session=session)
ti_op3.set_state(state=State.SUCCESS, session=session)
ti_op4.set_state(state=State.SUCCESS, session=session)
dr.update_state()
- self.assertEqual(State.SUCCESS, dr.state)
+ assert State.SUCCESS == dr.state
def test_dagrun_deadlock(self):
session = settings.Session()
@@ -224,12 +224,12 @@ def test_dagrun_deadlock(self):
ti_op2.set_state(state=State.NONE, session=session)
dr.update_state()
- self.assertEqual(dr.state, State.RUNNING)
+ assert dr.state == State.RUNNING
ti_op2.set_state(state=State.NONE, session=session)
op2.trigger_rule = 'invalid'
dr.update_state()
- self.assertEqual(dr.state, State.FAILED)
+ assert dr.state == State.FAILED
def test_dagrun_no_deadlock_with_shutdown(self):
session = settings.Session()
@@ -249,7 +249,7 @@ def test_dagrun_no_deadlock_with_shutdown(self):
upstream_ti.set_state(State.SHUTDOWN, session=session)
dr.update_state()
- self.assertEqual(dr.state, State.RUNNING)
+ assert dr.state == State.RUNNING
def test_dagrun_no_deadlock_with_depends_on_past(self):
session = settings.Session()
@@ -278,18 +278,18 @@ def test_dagrun_no_deadlock_with_depends_on_past(self):
ti1_op1.set_state(state=State.RUNNING, session=session)
dr.update_state()
dr2.update_state()
- self.assertEqual(dr.state, State.RUNNING)
- self.assertEqual(dr2.state, State.RUNNING)
+ assert dr.state == State.RUNNING
+ assert dr2.state == State.RUNNING
ti2_op1.set_state(state=State.RUNNING, session=session)
dr.update_state()
dr2.update_state()
- self.assertEqual(dr.state, State.RUNNING)
- self.assertEqual(dr2.state, State.RUNNING)
+ assert dr.state == State.RUNNING
+ assert dr2.state == State.RUNNING
def test_dagrun_success_callback(self):
def on_success_callable(context):
- self.assertEqual(context['dag_run'].dag_id, 'test_dagrun_success_callback')
+ assert context['dag_run'].dag_id == 'test_dagrun_success_callback'
dag = DAG(
dag_id='test_dagrun_success_callback',
@@ -310,13 +310,13 @@ def on_success_callable(context):
dag_run = self.create_dag_run(dag=dag, state=State.RUNNING, task_states=initial_task_states)
_, callback = dag_run.update_state()
- self.assertEqual(State.SUCCESS, dag_run.state)
+ assert State.SUCCESS == dag_run.state
# Callbacks are not added until handle_callback = False is passed to dag_run.update_state()
- self.assertIsNone(callback)
+ assert callback is None
def test_dagrun_failure_callback(self):
def on_failure_callable(context):
- self.assertEqual(context['dag_run'].dag_id, 'test_dagrun_failure_callback')
+ assert context['dag_run'].dag_id == 'test_dagrun_failure_callback'
dag = DAG(
dag_id='test_dagrun_failure_callback',
@@ -337,15 +337,13 @@ def on_failure_callable(context):
dag_run = self.create_dag_run(dag=dag, state=State.RUNNING, task_states=initial_task_states)
_, callback = dag_run.update_state()
- self.assertEqual(State.FAILED, dag_run.state)
+ assert State.FAILED == dag_run.state
# Callbacks are not added until handle_callback = False is passed to dag_run.update_state()
- self.assertIsNone(callback)
+ assert callback is None
def test_dagrun_update_state_with_handle_callback_success(self):
def on_success_callable(context):
- self.assertEqual(
- context['dag_run'].dag_id, 'test_dagrun_update_state_with_handle_callback_success'
- )
+ assert context['dag_run'].dag_id == 'test_dagrun_update_state_with_handle_callback_success'
dag = DAG(
dag_id='test_dagrun_update_state_with_handle_callback_success',
@@ -367,7 +365,7 @@ def on_success_callable(context):
dag_run = self.create_dag_run(dag=dag, state=State.RUNNING, task_states=initial_task_states)
_, callback = dag_run.update_state(execute_callbacks=False)
- self.assertEqual(State.SUCCESS, dag_run.state)
+ assert State.SUCCESS == dag_run.state
# Callbacks are not added until handle_callback = False is passed to dag_run.update_state()
assert callback == DagCallbackRequest(
@@ -380,9 +378,7 @@ def on_success_callable(context):
def test_dagrun_update_state_with_handle_callback_failure(self):
def on_failure_callable(context):
- self.assertEqual(
- context['dag_run'].dag_id, 'test_dagrun_update_state_with_handle_callback_failure'
- )
+ assert context['dag_run'].dag_id == 'test_dagrun_update_state_with_handle_callback_failure'
dag = DAG(
dag_id='test_dagrun_update_state_with_handle_callback_failure',
@@ -404,7 +400,7 @@ def on_failure_callable(context):
dag_run = self.create_dag_run(dag=dag, state=State.RUNNING, task_states=initial_task_states)
_, callback = dag_run.update_state(execute_callbacks=False)
- self.assertEqual(State.FAILED, dag_run.state)
+ assert State.FAILED == dag_run.state
# Callbacks are not added until handle_callback = False is passed to dag_run.update_state()
assert callback == DagCallbackRequest(
@@ -432,15 +428,15 @@ def test_dagrun_set_state_end_date(self):
# State.RUNNING set end_date back to NULL
session.add(dr)
session.commit()
- self.assertIsNone(dr.end_date)
+ assert dr.end_date is None
dr.set_state(State.SUCCESS)
session.merge(dr)
session.commit()
dr_database = session.query(DagRun).filter(DagRun.run_id == 'test_dagrun_set_state_end_date').one()
- self.assertIsNotNone(dr_database.end_date)
- self.assertEqual(dr.end_date, dr_database.end_date)
+ assert dr_database.end_date is not None
+ assert dr.end_date == dr_database.end_date
dr.set_state(State.RUNNING)
session.merge(dr)
@@ -448,15 +444,15 @@ def test_dagrun_set_state_end_date(self):
dr_database = session.query(DagRun).filter(DagRun.run_id == 'test_dagrun_set_state_end_date').one()
- self.assertIsNone(dr_database.end_date)
+ assert dr_database.end_date is None
dr.set_state(State.FAILED)
session.merge(dr)
session.commit()
dr_database = session.query(DagRun).filter(DagRun.run_id == 'test_dagrun_set_state_end_date').one()
- self.assertIsNotNone(dr_database.end_date)
- self.assertEqual(dr.end_date, dr_database.end_date)
+ assert dr_database.end_date is not None
+ assert dr.end_date == dr_database.end_date
def test_dagrun_update_state_end_date(self):
session = settings.Session()
@@ -486,7 +482,7 @@ def test_dagrun_update_state_end_date(self):
# State.RUNNING set end_date back to NULL
session.merge(dr)
session.commit()
- self.assertIsNone(dr.end_date)
+ assert dr.end_date is None
ti_op1 = dr.get_task_instance(task_id=op1.task_id)
ti_op1.set_state(state=State.SUCCESS, session=session)
@@ -496,8 +492,8 @@ def test_dagrun_update_state_end_date(self):
dr.update_state()
dr_database = session.query(DagRun).filter(DagRun.run_id == 'test_dagrun_update_state_end_date').one()
- self.assertIsNotNone(dr_database.end_date)
- self.assertEqual(dr.end_date, dr_database.end_date)
+ assert dr_database.end_date is not None
+ assert dr.end_date == dr_database.end_date
ti_op1.set_state(state=State.RUNNING, session=session)
ti_op2.set_state(state=State.RUNNING, session=session)
@@ -505,9 +501,9 @@ def test_dagrun_update_state_end_date(self):
dr_database = session.query(DagRun).filter(DagRun.run_id == 'test_dagrun_update_state_end_date').one()
- self.assertEqual(dr._state, State.RUNNING)
- self.assertIsNone(dr.end_date)
- self.assertIsNone(dr_database.end_date)
+ assert dr._state == State.RUNNING
+ assert dr.end_date is None
+ assert dr_database.end_date is None
ti_op1.set_state(state=State.FAILED, session=session)
ti_op2.set_state(state=State.FAILED, session=session)
@@ -515,8 +511,8 @@ def test_dagrun_update_state_end_date(self):
dr_database = session.query(DagRun).filter(DagRun.run_id == 'test_dagrun_update_state_end_date').one()
- self.assertIsNotNone(dr_database.end_date)
- self.assertEqual(dr.end_date, dr_database.end_date)
+ assert dr_database.end_date is not None
+ assert dr.end_date == dr_database.end_date
def test_get_task_instance_on_empty_dagrun(self):
"""
@@ -543,7 +539,7 @@ def test_get_task_instance_on_empty_dagrun(self):
session.commit()
ti = dag_run.get_task_instance('test_short_circuit_false')
- self.assertEqual(None, ti)
+ assert ti is None
def test_get_latest_runs(self):
session = settings.Session()
@@ -554,7 +550,7 @@ def test_get_latest_runs(self):
session.close()
for dagrun in dagruns:
if dagrun.dag_id == 'test_latest_runs_1':
- self.assertEqual(dagrun.execution_date, timezone.datetime(2015, 1, 2))
+ assert dagrun.execution_date == timezone.datetime(2015, 1, 2)
def test_is_backfill(self):
dag = DAG(dag_id='test_is_backfill', start_date=DEFAULT_DATE)
@@ -567,9 +563,9 @@ def test_is_backfill(self):
dagrun3 = self.create_dag_run(dag, execution_date=DEFAULT_DATE + datetime.timedelta(days=2))
dagrun3.run_id = None
- self.assertTrue(dagrun.is_backfill)
- self.assertFalse(dagrun2.is_backfill)
- self.assertFalse(dagrun3.is_backfill)
+ assert dagrun.is_backfill
+ assert not dagrun2.is_backfill
+ assert not dagrun3.is_backfill
def test_removed_task_instances_can_be_restored(self):
def with_all_tasks_removed(dag):
@@ -580,20 +576,20 @@ def with_all_tasks_removed(dag):
dagrun = self.create_dag_run(dag)
flaky_ti = dagrun.get_task_instances()[0]
- self.assertEqual('flaky_task', flaky_ti.task_id)
- self.assertEqual(State.NONE, flaky_ti.state)
+ assert 'flaky_task' == flaky_ti.task_id
+ assert State.NONE == flaky_ti.state
dagrun.dag = with_all_tasks_removed(dag)
dagrun.verify_integrity()
flaky_ti.refresh_from_db()
- self.assertEqual(State.NONE, flaky_ti.state)
+ assert State.NONE == flaky_ti.state
dagrun.dag.add_task(DummyOperator(task_id='flaky_task', owner='test'))
dagrun.verify_integrity()
flaky_ti.refresh_from_db()
- self.assertEqual(State.NONE, flaky_ti.state)
+ assert State.NONE == flaky_ti.state
def test_already_added_task_instances_can_be_ignored(self):
dag = DAG('triggered_dag', start_date=DEFAULT_DATE)
@@ -601,8 +597,8 @@ def test_already_added_task_instances_can_be_ignored(self):
dagrun = self.create_dag_run(dag)
first_ti = dagrun.get_task_instances()[0]
- self.assertEqual('first_task', first_ti.task_id)
- self.assertEqual(State.NONE, first_ti.state)
+ assert 'first_task' == first_ti.task_id
+ assert State.NONE == first_ti.state
# Lets assume that the above TI was added into DB by webserver, but if scheduler
# is running the same method at the same time it would find 0 TIs for this dag
@@ -612,7 +608,7 @@ def test_already_added_task_instances_can_be_ignored(self):
mock_gtis.return_value = []
dagrun.verify_integrity()
first_ti.refresh_from_db()
- self.assertEqual(State.NONE, first_ti.state)
+ assert State.NONE == first_ti.state
@parameterized.expand([(state,) for state in State.task_states])
@mock.patch('airflow.models.dagrun.task_instance_mutation_hook')
@@ -664,7 +660,7 @@ def test_depends_on_past(self, prev_ti_state, is_ti_success):
prev_ti.set_state(prev_ti_state)
ti.set_state(State.QUEUED)
ti.run()
- self.assertEqual(ti.state == State.SUCCESS, is_ti_success)
+ assert (ti.state == State.SUCCESS) == is_ti_success
@parameterized.expand(
[
@@ -689,12 +685,12 @@ def test_wait_for_downstream(self, prev_ti_state, is_ti_success):
ti = TI(task=upstream, execution_date=timezone.datetime(2016, 1, 2, 0, 0, 0))
prev_ti = ti.get_previous_ti()
prev_ti.set_state(State.SUCCESS)
- self.assertEqual(prev_ti.state, State.SUCCESS)
+ assert prev_ti.state == State.SUCCESS
prev_ti_downstream.set_state(prev_ti_state)
ti.set_state(State.QUEUED)
ti.run()
- self.assertEqual(ti.state == State.SUCCESS, is_ti_success)
+ assert (ti.state == State.SUCCESS) == is_ti_success
def test_next_dagruns_to_examine_only_unpaused(self):
"""
@@ -750,7 +746,7 @@ def test_no_scheduling_delay_for_nonscheduled_runs(self, stats_mock):
dag_run = self.create_dag_run(dag=dag, state=State.RUNNING, task_states=initial_task_states)
dag_run.update_state()
- self.assertNotIn(call(f'dagrun.{dag.dag_id}.first_task_scheduling_delay'), stats_mock.mock_calls)
+ assert call(f'dagrun.{dag.dag_id}.first_task_scheduling_delay') not in stats_mock.mock_calls
@parameterized.expand(
[
@@ -822,8 +818,8 @@ def test_states_sets(self):
dag_run = self.create_dag_run(dag=dag, state=State.RUNNING, task_states=initial_task_states)
ti_success = dag_run.get_task_instance(dag_task_success.task_id)
ti_failed = dag_run.get_task_instance(dag_task_failed.task_id)
- self.assertIn(ti_success.state, State.success_states)
- self.assertIn(ti_failed.state, State.failed_states)
+ assert ti_success.state in State.success_states
+ assert ti_failed.state in State.failed_states
def test_delete_dag_run_and_task_instance_does_not_raise_error(self):
clear_db_jobs()
diff --git a/tests/models/test_pool.py b/tests/models/test_pool.py
index 79609daac1f89..915d01dd078f9 100644
--- a/tests/models/test_pool.py
+++ b/tests/models/test_pool.py
@@ -59,27 +59,24 @@ def test_open_slots(self):
session.commit()
session.close()
- self.assertEqual(3, pool.open_slots()) # pylint: disable=no-value-for-parameter
- self.assertEqual(1, pool.running_slots()) # pylint: disable=no-value-for-parameter
- self.assertEqual(1, pool.queued_slots()) # pylint: disable=no-value-for-parameter
- self.assertEqual(2, pool.occupied_slots()) # pylint: disable=no-value-for-parameter
- self.assertEqual(
- {
- "default_pool": {
- "open": 128,
- "queued": 0,
- "total": 128,
- "running": 0,
- },
- "test_pool": {
- "open": 3,
- "queued": 1,
- "running": 1,
- "total": 5,
- },
+ assert 3 == pool.open_slots() # pylint: disable=no-value-for-parameter
+ assert 1 == pool.running_slots() # pylint: disable=no-value-for-parameter
+ assert 1 == pool.queued_slots() # pylint: disable=no-value-for-parameter
+ assert 2 == pool.occupied_slots() # pylint: disable=no-value-for-parameter
+ assert {
+ "default_pool": {
+ "open": 128,
+ "queued": 0,
+ "total": 128,
+ "running": 0,
},
- pool.slots_stats(),
- )
+ "test_pool": {
+ "open": 3,
+ "queued": 1,
+ "running": 1,
+ "total": 5,
+ },
+ } == pool.slots_stats()
def test_infinite_slots(self):
pool = Pool(pool='test_pool', slots=-1)
@@ -101,14 +98,14 @@ def test_infinite_slots(self):
session.commit()
session.close()
- self.assertEqual(float('inf'), pool.open_slots()) # pylint: disable=no-value-for-parameter
- self.assertEqual(1, pool.running_slots()) # pylint: disable=no-value-for-parameter
- self.assertEqual(1, pool.queued_slots()) # pylint: disable=no-value-for-parameter
- self.assertEqual(2, pool.occupied_slots()) # pylint: disable=no-value-for-parameter
+ assert float('inf') == pool.open_slots() # pylint: disable=no-value-for-parameter
+ assert 1 == pool.running_slots() # pylint: disable=no-value-for-parameter
+ assert 1 == pool.queued_slots() # pylint: disable=no-value-for-parameter
+ assert 2 == pool.occupied_slots() # pylint: disable=no-value-for-parameter
def test_default_pool_open_slots(self):
set_default_pool_slots(5)
- self.assertEqual(5, Pool.get_default_pool().open_slots())
+ assert 5 == Pool.get_default_pool().open_slots()
dag = DAG(
dag_id='test_default_pool_open_slots',
@@ -127,4 +124,4 @@ def test_default_pool_open_slots(self):
session.commit()
session.close()
- self.assertEqual(2, Pool.get_default_pool().open_slots())
+ assert 2 == Pool.get_default_pool().open_slots()
diff --git a/tests/models/test_renderedtifields.py b/tests/models/test_renderedtifields.py
index 9482f08a23ef1..1cf4e3fac4e35 100644
--- a/tests/models/test_renderedtifields.py
+++ b/tests/models/test_renderedtifields.py
@@ -117,17 +117,15 @@ def test_get_templated_fields(self, templated_field, expected_rendered_field):
ti = TI(task=task, execution_date=EXECUTION_DATE)
rtif = RTIF(ti=ti)
- self.assertEqual(ti.dag_id, rtif.dag_id)
- self.assertEqual(ti.task_id, rtif.task_id)
- self.assertEqual(ti.execution_date, rtif.execution_date)
- self.assertEqual(expected_rendered_field, rtif.rendered_fields.get("bash_command"))
+ assert ti.dag_id == rtif.dag_id
+ assert ti.task_id == rtif.task_id
+ assert ti.execution_date == rtif.execution_date
+ assert expected_rendered_field == rtif.rendered_fields.get("bash_command")
with create_session() as session:
session.add(rtif)
- self.assertEqual(
- {"bash_command": expected_rendered_field, "env": None}, RTIF.get_templated_fields(ti=ti)
- )
+ assert {"bash_command": expected_rendered_field, "env": None} == RTIF.get_templated_fields(ti=ti)
# Test the else part of get_templated_fields
# i.e. for the TIs that are not stored in RTIF table
@@ -136,7 +134,7 @@ def test_get_templated_fields(self, templated_field, expected_rendered_field):
task_2 = BashOperator(task_id="test2", bash_command=templated_field)
ti2 = TI(task_2, EXECUTION_DATE)
- self.assertIsNone(RTIF.get_templated_fields(ti=ti2))
+ assert RTIF.get_templated_fields(ti=ti2) is None
@parameterized.expand(
[
@@ -169,15 +167,15 @@ def test_delete_old_records(self, rtif_num, num_to_keep, remaining_rtifs, expect
result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()
for rtif in rtif_list:
- self.assertIn(rtif, result)
+ assert rtif in result
- self.assertEqual(rtif_num, len(result))
+ assert rtif_num == len(result)
# Verify old records are deleted and only 'num_to_keep' records are kept
with assert_queries_count(expected_query_count):
RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=num_to_keep)
result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()
- self.assertEqual(remaining_rtifs, len(result))
+ assert remaining_rtifs == len(result)
def test_write(self):
"""
@@ -187,7 +185,7 @@ def test_write(self):
session = settings.Session()
result = session.query(RTIF).all()
- self.assertEqual([], result)
+ assert [] == result
with DAG("test_write", start_date=START_DATE):
task = BashOperator(task_id="test", bash_command="echo {{ var.value.test_key }}")
@@ -203,7 +201,7 @@ def test_write(self):
)
.first()
)
- self.assertEqual(('test_write', 'test', {'bash_command': 'echo test_val', 'env': None}), result)
+ assert ('test_write', 'test', {'bash_command': 'echo test_val', 'env': None}) == result
# Test that overwrite saves new values to the DB
Variable.delete("test_key")
@@ -224,9 +222,11 @@ def test_write(self):
)
.first()
)
- self.assertEqual(
- ('test_write', 'test', {'bash_command': 'echo test_val_updated', 'env': None}), result_updated
- )
+ assert (
+ 'test_write',
+ 'test',
+ {'bash_command': 'echo test_val_updated', 'env': None},
+ ) == result_updated
@mock.patch.dict(os.environ, {"AIRFLOW_IS_K8S_EXECUTOR_POD": "True"})
@mock.patch("airflow.settings.pod_mutation_hook")
@@ -245,9 +245,9 @@ def test_get_k8s_pod_yaml(self, mock_pod_mutation_hook):
# Test that pod_mutation_hook is called
mock_pod_mutation_hook.assert_called_once_with(mock.ANY)
- self.assertEqual(ti.dag_id, rtif.dag_id)
- self.assertEqual(ti.task_id, rtif.task_id)
- self.assertEqual(ti.execution_date, rtif.execution_date)
+ assert ti.dag_id == rtif.dag_id
+ assert ti.task_id == rtif.task_id
+ assert ti.execution_date == rtif.execution_date
expected_pod_yaml = {
'metadata': {
@@ -288,12 +288,12 @@ def test_get_k8s_pod_yaml(self, mock_pod_mutation_hook):
},
}
- self.assertEqual(expected_pod_yaml, rtif.k8s_pod_yaml)
+ assert expected_pod_yaml == rtif.k8s_pod_yaml
with create_session() as session:
session.add(rtif)
- self.assertEqual(expected_pod_yaml, RTIF.get_k8s_pod_yaml(ti=ti))
+ assert expected_pod_yaml == RTIF.get_k8s_pod_yaml(ti=ti)
# Test the else part of get_k8s_pod_yaml
# i.e. for the TIs that are not stored in RTIF table
@@ -302,4 +302,4 @@ def test_get_k8s_pod_yaml(self, mock_pod_mutation_hook):
task_2 = BashOperator(task_id="test2", bash_command="echo hello")
ti2 = TI(task_2, EXECUTION_DATE)
- self.assertIsNone(RTIF.get_k8s_pod_yaml(ti=ti2))
+ assert RTIF.get_k8s_pod_yaml(ti=ti2) is None
diff --git a/tests/models/test_sensorinstance.py b/tests/models/test_sensorinstance.py
index 6246df8c14d98..7eaa2ac034e10 100644
--- a/tests/models/test_sensorinstance.py
+++ b/tests/models/test_sensorinstance.py
@@ -32,7 +32,7 @@ def test_get_classpath(self):
"airflow.providers.apache.hive.sensors.named_hive_partition.NamedHivePartitionSensor"
)
- self.assertEqual(obj1_classpath, obj1_importpath)
+ assert obj1_classpath == obj1_importpath
def test_callable():
return
@@ -41,4 +41,4 @@ def test_callable():
obj3_classpath = SensorInstance.get_classpath(obj3)
obj3_importpath = "airflow.sensors.python.PythonSensor"
- self.assertEqual(obj3_classpath, obj3_importpath)
+ assert obj3_classpath == obj3_importpath
diff --git a/tests/models/test_serialized_dag.py b/tests/models/test_serialized_dag.py
index 590ceed56220d..72380c55d153d 100644
--- a/tests/models/test_serialized_dag.py
+++ b/tests/models/test_serialized_dag.py
@@ -52,7 +52,7 @@ def tearDown(self):
def test_dag_fileloc_hash(self):
"""Verifies the correctness of hashing file path."""
- self.assertEqual(DagCode.dag_fileloc_hash('/airflow/dags/test_dag.py'), 33826252060516589)
+ assert DagCode.dag_fileloc_hash('/airflow/dags/test_dag.py') == 33826252060516589
def _write_example_dags(self):
example_dags = make_example_dags(example_dags_module)
@@ -66,10 +66,10 @@ def test_write_dag(self):
with create_session() as session:
for dag in example_dags.values():
- self.assertTrue(SDM.has_dag(dag.dag_id))
+ assert SDM.has_dag(dag.dag_id)
result = session.query(SDM.fileloc, SDM.data).filter(SDM.dag_id == dag.dag_id).one()
- self.assertTrue(result.fileloc == dag.full_filepath)
+ assert result.fileloc == dag.full_filepath
# Verifies JSON schema.
SerializedDAG.validate_schema(result.data)
@@ -88,30 +88,30 @@ def test_serialized_dag_is_updated_only_if_dag_is_changed(self):
SDM.write_dag(dag=example_bash_op_dag)
s_dag_1 = session.query(SDM).get(example_bash_op_dag.dag_id)
- self.assertEqual(s_dag_1.dag_hash, s_dag.dag_hash)
- self.assertEqual(s_dag.last_updated, s_dag_1.last_updated)
+ assert s_dag_1.dag_hash == s_dag.dag_hash
+ assert s_dag.last_updated == s_dag_1.last_updated
# Update DAG
example_bash_op_dag.tags += ["new_tag"]
- self.assertCountEqual(example_bash_op_dag.tags, ["example", "example2", "new_tag"])
+ assert set(example_bash_op_dag.tags) == {"example", "example2", "new_tag"}
SDM.write_dag(dag=example_bash_op_dag)
s_dag_2 = session.query(SDM).get(example_bash_op_dag.dag_id)
- self.assertNotEqual(s_dag.last_updated, s_dag_2.last_updated)
- self.assertNotEqual(s_dag.dag_hash, s_dag_2.dag_hash)
- self.assertEqual(s_dag_2.data["dag"]["tags"], ["example", "example2", "new_tag"])
+ assert s_dag.last_updated != s_dag_2.last_updated
+ assert s_dag.dag_hash != s_dag_2.dag_hash
+ assert s_dag_2.data["dag"]["tags"] == ["example", "example2", "new_tag"]
def test_read_dags(self):
"""DAGs can be read from database."""
example_dags = self._write_example_dags()
serialized_dags = SDM.read_all_dags()
- self.assertTrue(len(example_dags) == len(serialized_dags))
+ assert len(example_dags) == len(serialized_dags)
for dag_id, dag in example_dags.items():
serialized_dag = serialized_dags[dag_id]
- self.assertTrue(serialized_dag.dag_id == dag.dag_id)
- self.assertTrue(set(serialized_dag.task_dict) == set(dag.task_dict))
+ assert serialized_dag.dag_id == dag.dag_id
+ assert set(serialized_dag.task_dict) == set(dag.task_dict)
def test_remove_dags_by_id(self):
"""DAGs can be removed from database."""
@@ -122,7 +122,7 @@ def test_remove_dags_by_id(self):
# Tests removing by dag_id.
dag_removed_by_id = filtered_example_dags_list[0]
SDM.remove_dag(dag_removed_by_id.dag_id)
- self.assertFalse(SDM.has_dag(dag_removed_by_id.dag_id))
+ assert not SDM.has_dag(dag_removed_by_id.dag_id)
def test_remove_dags_by_filepath(self):
"""DAGs can be removed from database."""
@@ -136,7 +136,7 @@ def test_remove_dags_by_filepath(self):
example_dag_files = list({dag.full_filepath for dag in filtered_example_dags_list})
example_dag_files.remove(dag_removed_by_file.full_filepath)
SDM.remove_deleted_dags(example_dag_files)
- self.assertFalse(SDM.has_dag(dag_removed_by_file.dag_id))
+ assert not SDM.has_dag(dag_removed_by_file.dag_id)
def test_bulk_sync_to_db(self):
dags = [
diff --git a/tests/models/test_skipmixin.py b/tests/models/test_skipmixin.py
index 814547b0ae4cd..00b14fda659ec 100644
--- a/tests/models/test_skipmixin.py
+++ b/tests/models/test_skipmixin.py
@@ -85,8 +85,8 @@ def test_skip_none_dagrun(self, mock_now):
def test_skip_none_tasks(self):
session = Mock()
SkipMixin().skip(dag_run=None, execution_date=None, tasks=[], session=session)
- self.assertFalse(session.query.called)
- self.assertFalse(session.commit.called)
+ assert not session.query.called
+ assert not session.commit.called
def test_skip_all_except(self):
dag = DAG(
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index e6b53902ad516..012c547e9177b 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -121,11 +121,11 @@ def test_set_task_dates(self):
op1 = DummyOperator(task_id='op_1', owner='test')
- self.assertTrue(op1.start_date is None and op1.end_date is None)
+ assert op1.start_date is None and op1.end_date is None
# dag should assign its dates to op1 because op1 has no dates
dag.add_task(op1)
- self.assertTrue(op1.start_date == dag.start_date and op1.end_date == dag.end_date)
+ assert op1.start_date == dag.start_date and op1.end_date == dag.end_date
op2 = DummyOperator(
task_id='op_2',
@@ -136,7 +136,7 @@ def test_set_task_dates(self):
# dag should assign its dates to op2 because they are more restrictive
dag.add_task(op2)
- self.assertTrue(op2.start_date == dag.start_date and op2.end_date == dag.end_date)
+ assert op2.start_date == dag.start_date and op2.end_date == dag.end_date
op3 = DummyOperator(
task_id='op_3',
@@ -146,8 +146,8 @@ def test_set_task_dates(self):
)
# op3 should keep its dates because they are more restrictive
dag.add_task(op3)
- self.assertTrue(op3.start_date == DEFAULT_DATE + datetime.timedelta(days=1))
- self.assertTrue(op3.end_date == DEFAULT_DATE + datetime.timedelta(days=9))
+ assert op3.start_date == DEFAULT_DATE + datetime.timedelta(days=1)
+ assert op3.end_date == DEFAULT_DATE + datetime.timedelta(days=9)
def test_timezone_awareness(self):
naive_datetime = DEFAULT_DATE.replace(tzinfo=None)
@@ -156,7 +156,7 @@ def test_timezone_awareness(self):
op_no_dag = DummyOperator(task_id='op_no_dag')
ti = TI(task=op_no_dag, execution_date=naive_datetime)
- self.assertEqual(ti.execution_date, DEFAULT_DATE)
+ assert ti.execution_date == DEFAULT_DATE
# check with dag without localized execution_date
dag = DAG('dag', start_date=DEFAULT_DATE)
@@ -164,14 +164,14 @@ def test_timezone_awareness(self):
dag.add_task(op1)
ti = TI(task=op1, execution_date=naive_datetime)
- self.assertEqual(ti.execution_date, DEFAULT_DATE)
+ assert ti.execution_date == DEFAULT_DATE
# with dag and localized execution_date
tzinfo = pendulum.timezone("Europe/Amsterdam")
execution_date = timezone.datetime(2016, 1, 1, 1, 0, 0, tzinfo=tzinfo)
utc_date = timezone.convert_to_utc(execution_date)
ti = TI(task=op1, execution_date=execution_date)
- self.assertEqual(ti.execution_date, utc_date)
+ assert ti.execution_date == utc_date
def test_task_naive_datetime(self):
naive_datetime = DEFAULT_DATE.replace(tzinfo=None)
@@ -180,8 +180,8 @@ def test_task_naive_datetime(self):
task_id='test_task_naive_datetime', start_date=naive_datetime, end_date=naive_datetime
)
- self.assertTrue(op_no_dag.start_date.tzinfo)
- self.assertTrue(op_no_dag.end_date.tzinfo)
+ assert op_no_dag.start_date.tzinfo
+ assert op_no_dag.end_date.tzinfo
def test_set_dag(self):
"""
@@ -192,24 +192,25 @@ def test_set_dag(self):
op = DummyOperator(task_id='op_1', owner='test')
# no dag assigned
- self.assertFalse(op.has_dag())
- self.assertRaises(AirflowException, getattr, op, 'dag')
+ assert not op.has_dag()
+ with pytest.raises(AirflowException):
+ getattr(op, 'dag')
# no improper assignment
- with self.assertRaises(TypeError):
+ with pytest.raises(TypeError):
op.dag = 1
op.dag = dag
# no reassignment
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
op.dag = dag2
# but assigning the same dag is ok
op.dag = dag
- self.assertIs(op.dag, dag)
- self.assertIn(op, dag.tasks)
+ assert op.dag is dag
+ assert op in dag.tasks
def test_infer_dag(self):
dag = DAG('dag', start_date=DEFAULT_DATE)
@@ -221,19 +222,22 @@ def test_infer_dag(self):
op4 = DummyOperator(task_id='test_op_4', owner='test', dag=dag2)
# double check dags
- self.assertEqual([i.has_dag() for i in [op1, op2, op3, op4]], [False, False, True, True])
+ assert [i.has_dag() for i in [op1, op2, op3, op4]] == [False, False, True, True]
# can't combine operators with no dags
- self.assertRaises(AirflowException, op1.set_downstream, op2)
+ with pytest.raises(AirflowException):
+ op1.set_downstream(op2)
# op2 should infer dag from op1
op1.dag = dag
op1.set_downstream(op2)
- self.assertIs(op2.dag, dag)
+ assert op2.dag is dag
# can't assign across multiple DAGs
- self.assertRaises(AirflowException, op1.set_downstream, op4)
- self.assertRaises(AirflowException, op1.set_downstream, [op3, op4])
+ with pytest.raises(AirflowException):
+ op1.set_downstream(op4)
+ with pytest.raises(AirflowException):
+ op1.set_downstream([op3, op4])
def test_bitshift_compose_operators(self):
dag = DAG('dag', start_date=DEFAULT_DATE)
@@ -245,8 +249,8 @@ def test_bitshift_compose_operators(self):
op1 >> op2 << op3
# op2 should be downstream of both
- self.assertIn(op2, op1.downstream_list)
- self.assertIn(op2, op3.downstream_list)
+ assert op2 in op1.downstream_list
+ assert op2 in op3.downstream_list
@patch.object(DAG, 'get_concurrency_reached')
def test_requeue_over_dag_concurrency(self, mock_concurrency_reached):
@@ -266,7 +270,7 @@ def test_requeue_over_dag_concurrency(self, mock_concurrency_reached):
session.add(ti)
session.commit()
ti.run()
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
def test_requeue_over_task_concurrency(self):
dag = DAG(
@@ -283,7 +287,7 @@ def test_requeue_over_task_concurrency(self):
session.add(ti)
session.commit()
ti.run()
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
def test_requeue_over_pool_concurrency(self):
dag = DAG(
@@ -302,7 +306,7 @@ def test_requeue_over_pool_concurrency(self):
session.add(ti)
session.commit()
ti.run()
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
def test_not_requeue_non_requeueable_task_instance(self):
dag = models.DAG(dag_id='test_not_requeue_non_requeueable_task_instance')
@@ -333,7 +337,7 @@ def test_not_requeue_non_requeueable_task_instance(self):
for class_name, (dep_patch, method_patch) in patch_dict.items():
method_patch.return_value = iter([TIDepStatus('mock_' + class_name, False, 'mock')])
ti.run()
- self.assertEqual(ti.state, State.QUEUED)
+ assert ti.state == State.QUEUED
dep_patch.return_value = TIDepStatus('mock_' + class_name, True, 'mock')
for (dep_patch, method_patch) in patch_dict.values():
@@ -366,7 +370,7 @@ def test_mark_non_runnable_task_as_success(self):
)
session.commit()
ti.run(mark_success=True)
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
def test_run_pooling_task(self):
"""
@@ -390,7 +394,7 @@ def test_run_pooling_task(self):
ti.run()
db.clear_db_pools()
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
def test_pool_slots_property(self):
"""
@@ -409,7 +413,8 @@ def create_task_instance():
)
return TI(task=task, execution_date=timezone.utcnow())
- self.assertRaises(AirflowException, create_task_instance)
+ with pytest.raises(AirflowException):
+ create_task_instance()
@provide_session
def test_ti_updates_with_task(self, session=None):
@@ -434,7 +439,7 @@ def test_ti_updates_with_task(self, session=None):
ti.run(session=session)
tis = dag.get_task_instances()
- self.assertEqual({'foo': 'bar'}, tis[0].executor_config)
+ assert {'foo': 'bar'} == tis[0].executor_config
with models.DAG(dag_id='test_run_pooling_task') as dag:
task2 = DummyOperator(
task_id='test_run_pooling_task_op',
@@ -453,7 +458,7 @@ def test_ti_updates_with_task(self, session=None):
)
ti.run(session=session)
tis = dag.get_task_instances()
- self.assertEqual({'bar': 'baz'}, tis[1].executor_config)
+ assert {'bar': 'baz'} == tis[1].executor_config
session.rollback()
def test_run_pooling_task_with_mark_success(self):
@@ -478,7 +483,7 @@ def test_run_pooling_task_with_mark_success(self):
run_type=DagRunType.SCHEDULED,
)
ti.run(mark_success=True)
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
def test_run_pooling_task_with_skip(self):
"""
@@ -504,7 +509,7 @@ def raise_skip_exception():
run_type=DagRunType.SCHEDULED,
)
ti.run()
- self.assertEqual(State.SKIPPED, ti.state)
+ assert State.SKIPPED == ti.state
def test_retry_delay(self):
"""
@@ -534,20 +539,20 @@ def run_with_error(ti):
run_type=DagRunType.SCHEDULED,
)
- self.assertEqual(ti.try_number, 1)
+ assert ti.try_number == 1
# first run -- up for retry
run_with_error(ti)
- self.assertEqual(ti.state, State.UP_FOR_RETRY)
- self.assertEqual(ti.try_number, 2)
+ assert ti.state == State.UP_FOR_RETRY
+ assert ti.try_number == 2
# second run -- still up for retry because retry_delay hasn't expired
run_with_error(ti)
- self.assertEqual(ti.state, State.UP_FOR_RETRY)
+ assert ti.state == State.UP_FOR_RETRY
# third run -- failed
time.sleep(3)
run_with_error(ti)
- self.assertEqual(ti.state, State.FAILED)
+ assert ti.state == State.FAILED
def test_retry_handling(self):
"""
@@ -573,19 +578,19 @@ def run_with_error(ti):
pass
ti = TI(task=task, execution_date=timezone.utcnow())
- self.assertEqual(ti.try_number, 1)
+ assert ti.try_number == 1
# first run -- up for retry
run_with_error(ti)
- self.assertEqual(ti.state, State.UP_FOR_RETRY)
- self.assertEqual(ti._try_number, 1)
- self.assertEqual(ti.try_number, 2)
+ assert ti.state == State.UP_FOR_RETRY
+ assert ti._try_number == 1
+ assert ti.try_number == 2
# second run -- fail
run_with_error(ti)
- self.assertEqual(ti.state, State.FAILED)
- self.assertEqual(ti._try_number, 2)
- self.assertEqual(ti.try_number, 3)
+ assert ti.state == State.FAILED
+ assert ti._try_number == 2
+ assert ti.try_number == 3
# Clear the TI state since you can't run a task with a FAILED state without
# clearing it first
@@ -593,17 +598,17 @@ def run_with_error(ti):
# third run -- up for retry
run_with_error(ti)
- self.assertEqual(ti.state, State.UP_FOR_RETRY)
- self.assertEqual(ti._try_number, 3)
- self.assertEqual(ti.try_number, 4)
+ assert ti.state == State.UP_FOR_RETRY
+ assert ti._try_number == 3
+ assert ti.try_number == 4
# fourth run -- fail
run_with_error(ti)
ti.refresh_from_db()
- self.assertEqual(ti.state, State.FAILED)
- self.assertEqual(ti._try_number, 4)
- self.assertEqual(ti.try_number, 5)
- self.assertEqual(RenderedTaskInstanceFields.get_templated_fields(ti), expected_rendered_ti_fields)
+ assert ti.state == State.FAILED
+ assert ti._try_number == 4
+ assert ti.try_number == 5
+ assert RenderedTaskInstanceFields.get_templated_fields(ti) == expected_rendered_ti_fields
def test_next_retry_datetime(self):
delay = datetime.timedelta(seconds=30)
@@ -627,27 +632,27 @@ def test_next_retry_datetime(self):
date = ti.next_retry_datetime()
# between 30 * 2^0.5 and 30 * 2^1 (15 and 30)
period = ti.end_date.add(seconds=30) - ti.end_date.add(seconds=15)
- self.assertTrue(date in period)
+ assert date in period
ti.try_number = 3
date = ti.next_retry_datetime()
# between 30 * 2^2 and 30 * 2^3 (120 and 240)
period = ti.end_date.add(seconds=240) - ti.end_date.add(seconds=120)
- self.assertTrue(date in period)
+ assert date in period
ti.try_number = 5
date = ti.next_retry_datetime()
# between 30 * 2^4 and 30 * 2^5 (480 and 960)
period = ti.end_date.add(seconds=960) - ti.end_date.add(seconds=480)
- self.assertTrue(date in period)
+ assert date in period
ti.try_number = 9
date = ti.next_retry_datetime()
- self.assertEqual(date, ti.end_date + max_delay)
+ assert date == ti.end_date + max_delay
ti.try_number = 50
date = ti.next_retry_datetime()
- self.assertEqual(date, ti.end_date + max_delay)
+ assert date == ti.end_date + max_delay
def test_next_retry_datetime_short_intervals(self):
delay = datetime.timedelta(seconds=1)
@@ -671,7 +676,7 @@ def test_next_retry_datetime_short_intervals(self):
date = ti.next_retry_datetime()
# between 1 * 2^0.5 and 1 * 2^1 (15 and 30)
period = ti.end_date.add(seconds=15) - ti.end_date.add(seconds=1)
- self.assertTrue(date in period)
+ assert date in period
def test_reschedule_handling(self):
"""
@@ -701,8 +706,8 @@ def func():
)
ti = TI(task=task, execution_date=timezone.utcnow())
- self.assertEqual(ti._try_number, 0)
- self.assertEqual(ti.try_number, 1)
+ assert ti._try_number == 0
+ assert ti.try_number == 1
dag.create_dagrun(
execution_date=ti.execution_date,
@@ -726,14 +731,14 @@ def run_ti_and_assert(
if not fail:
raise
ti.refresh_from_db()
- self.assertEqual(ti.state, expected_state)
- self.assertEqual(ti._try_number, expected_try_number)
- self.assertEqual(ti.try_number, expected_try_number + 1)
- self.assertEqual(ti.start_date, expected_start_date)
- self.assertEqual(ti.end_date, expected_end_date)
- self.assertEqual(ti.duration, expected_duration)
+ assert ti.state == expected_state
+ assert ti._try_number == expected_try_number
+ assert ti.try_number == expected_try_number + 1
+ assert ti.start_date == expected_start_date
+ assert ti.end_date == expected_end_date
+ assert ti.duration == expected_duration
trs = TaskReschedule.find_for_task_instance(ti) # pylint: disable=no-value-for-parameter
- self.assertEqual(len(trs), expected_task_reschedule_count)
+ assert len(trs) == expected_task_reschedule_count
date1 = timezone.utcnow()
date2 = date1 + datetime.timedelta(minutes=1)
@@ -761,8 +766,8 @@ def run_ti_and_assert(
# Clear the task instance.
dag.clear()
ti.refresh_from_db()
- self.assertEqual(ti.state, State.NONE)
- self.assertEqual(ti._try_number, 1)
+ assert ti.state == State.NONE
+ assert ti._try_number == 1
# Run again after clearing with reschedules and a retry.
# The retry increments the try number, and for that try no reschedule is expected.
@@ -808,8 +813,8 @@ def func():
)
ti = TI(task=task, execution_date=timezone.utcnow())
- self.assertEqual(ti._try_number, 0)
- self.assertEqual(ti.try_number, 1)
+ assert ti._try_number == 0
+ assert ti.try_number == 1
def run_ti_and_assert(
run_date,
@@ -827,14 +832,14 @@ def run_ti_and_assert(
if not fail:
raise
ti.refresh_from_db()
- self.assertEqual(ti.state, expected_state)
- self.assertEqual(ti._try_number, expected_try_number)
- self.assertEqual(ti.try_number, expected_try_number + 1)
- self.assertEqual(ti.start_date, expected_start_date)
- self.assertEqual(ti.end_date, expected_end_date)
- self.assertEqual(ti.duration, expected_duration)
+ assert ti.state == expected_state
+ assert ti._try_number == expected_try_number
+ assert ti.try_number == expected_try_number + 1
+ assert ti.start_date == expected_start_date
+ assert ti.end_date == expected_end_date
+ assert ti.duration == expected_duration
trs = TaskReschedule.find_for_task_instance(ti) # pylint: disable=no-value-for-parameter
- self.assertEqual(len(trs), expected_task_reschedule_count)
+ assert len(trs) == expected_task_reschedule_count
date1 = timezone.utcnow()
@@ -844,11 +849,11 @@ def run_ti_and_assert(
# Clear the task instance.
dag.clear()
ti.refresh_from_db()
- self.assertEqual(ti.state, State.NONE)
- self.assertEqual(ti._try_number, 0)
+ assert ti.state == State.NONE
+ assert ti._try_number == 0
# Check that reschedules for ti have also been cleared.
trs = TaskReschedule.find_for_task_instance(ti) # pylint: disable=no-value-for-parameter
- self.assertFalse(trs)
+ assert not trs
def test_depends_on_past(self):
dag = DAG(dag_id='test_depends_on_past', start_date=DEFAULT_DATE)
@@ -873,12 +878,12 @@ def test_depends_on_past(self):
# depends_on_past prevents the run
task.run(start_date=run_date, end_date=run_date, ignore_first_depends_on_past=False)
ti.refresh_from_db()
- self.assertIs(ti.state, None)
+ assert ti.state is None
# ignore first depends_on_past to allow the run
task.run(start_date=run_date, end_date=run_date, ignore_first_depends_on_past=True)
ti.refresh_from_db()
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
# Parameterized tests to check for the correct firing
# of the trigger_rule under various circumstances
@@ -957,8 +962,8 @@ def test_check_task_dependencies(
)
completed = all(dep.passed for dep in dep_results)
- self.assertEqual(completed, expect_completed)
- self.assertEqual(ti.state, expect_state)
+ assert completed == expect_completed
+ assert ti.state == expect_state
def test_respects_prev_dagrun_dep(self):
with DAG(dag_id='test_dag'):
@@ -969,11 +974,11 @@ def test_respects_prev_dagrun_dep(self):
with patch(
'airflow.ti_deps.deps.prev_dagrun_dep.PrevDagrunDep.get_dep_statuses', return_value=failing_status
):
- self.assertFalse(ti.are_dependencies_met())
+ assert not ti.are_dependencies_met()
with patch(
'airflow.ti_deps.deps.prev_dagrun_dep.PrevDagrunDep.get_dep_statuses', return_value=passing_status
):
- self.assertTrue(ti.are_dependencies_met())
+ assert ti.are_dependencies_met()
@parameterized.expand(
[
@@ -994,7 +999,7 @@ def test_are_dependents_done(self, downstream_ti_state, expected_are_dependents_
downstream_ti = TI(downstream_task, DEFAULT_DATE)
downstream_ti.set_state(downstream_ti_state)
- self.assertEqual(ti.are_dependents_done(), expected_are_dependents_done)
+ assert ti.are_dependents_done() == expected_are_dependents_done
def test_xcom_pull(self):
"""
@@ -1020,19 +1025,19 @@ def test_xcom_pull(self):
# Pull with no arguments
result = ti1.xcom_pull()
- self.assertEqual(result, None)
+ assert result is None
# Pull the value pushed most recently by any task.
result = ti1.xcom_pull(key='foo')
- self.assertIn(result, 'baz')
+ assert result in 'baz'
# Pull the value pushed by the first task
result = ti1.xcom_pull(task_ids='test_xcom_1', key='foo')
- self.assertEqual(result, 'bar')
+ assert result == 'bar'
# Pull the value pushed by the second task
result = ti1.xcom_pull(task_ids='test_xcom_2', key='foo')
- self.assertEqual(result, 'baz')
+ assert result == 'baz'
# Pull the values pushed by both tasks & Verify Order of task_ids pass & values returned
result = ti1.xcom_pull(task_ids=['test_xcom_1', 'test_xcom_2'], key='foo')
- self.assertEqual(result, ['bar', 'baz'])
+ assert result == ['bar', 'baz']
def test_xcom_pull_after_success(self):
"""
@@ -1060,19 +1065,19 @@ def test_xcom_pull_after_success(self):
ti.run(mark_success=True)
ti.xcom_push(key=key, value=value)
- self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value)
+ assert ti.xcom_pull(task_ids='test_xcom', key=key) == value
ti.run()
# The second run and assert is to handle AIRFLOW-131 (don't clear on
# prior success)
- self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value)
+ assert ti.xcom_pull(task_ids='test_xcom', key=key) == value
# Test AIRFLOW-703: Xcom shouldn't be cleared if the task doesn't
# execute, even if dependencies are ignored
ti.run(ignore_all_deps=True, mark_success=True)
- self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value)
+ assert ti.xcom_pull(task_ids='test_xcom', key=key) == value
# Xcom IS finally cleared once task has executed
ti.run(ignore_all_deps=True)
- self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), None)
+ assert ti.xcom_pull(task_ids='test_xcom', key=key) is None
def test_xcom_pull_different_execution_date(self):
"""
@@ -1101,7 +1106,7 @@ def test_xcom_pull_different_execution_date(self):
ti.run(mark_success=True)
ti.xcom_push(key=key, value=value)
- self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value)
+ assert ti.xcom_pull(task_ids='test_xcom', key=key) == value
ti.run()
exec_date += datetime.timedelta(days=1)
ti = TI(task=task, execution_date=exec_date)
@@ -1109,9 +1114,9 @@ def test_xcom_pull_different_execution_date(self):
# We have set a new execution date (and did not pass in
# 'include_prior_dates'which means this task should now have a cleared
# xcom value
- self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), None)
+ assert ti.xcom_pull(task_ids='test_xcom', key=key) is None
# We *should* get a value using 'include_prior_dates'
- self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key, include_prior_dates=True), value)
+ assert ti.xcom_pull(task_ids='test_xcom', key=key, include_prior_dates=True) == value
def test_xcom_push_flag(self):
"""
@@ -1137,7 +1142,7 @@ def test_xcom_push_flag(self):
run_type=DagRunType.SCHEDULED,
)
ti.run()
- self.assertEqual(ti.xcom_pull(task_ids=task_id, key=models.XCOM_RETURN_KEY), None)
+ assert ti.xcom_pull(task_ids=task_id, key=models.XCOM_RETURN_KEY) is None
def test_post_execute_hook(self):
"""
@@ -1163,18 +1168,18 @@ def post_execute(self, context, result=None):
)
ti = TI(task=task, execution_date=timezone.utcnow())
- with self.assertRaises(TestError):
+ with pytest.raises(TestError):
ti.run()
def test_check_and_change_state_before_execution(self):
dag = models.DAG(dag_id='test_check_and_change_state_before_execution')
task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE)
ti = TI(task=task, execution_date=timezone.utcnow())
- self.assertEqual(ti._try_number, 0)
- self.assertTrue(ti.check_and_change_state_before_execution())
+ assert ti._try_number == 0
+ assert ti.check_and_change_state_before_execution()
# State should be running, and try_number column should be incremented
- self.assertEqual(ti.state, State.RUNNING)
- self.assertEqual(ti._try_number, 1)
+ assert ti.state == State.RUNNING
+ assert ti._try_number == 1
def test_check_and_change_state_before_execution_dep_not_met(self):
dag = models.DAG(dag_id='test_check_and_change_state_before_execution')
@@ -1182,7 +1187,7 @@ def test_check_and_change_state_before_execution_dep_not_met(self):
task2 = DummyOperator(task_id='task2', dag=dag, start_date=DEFAULT_DATE)
task >> task2
ti = TI(task=task2, execution_date=timezone.utcnow())
- self.assertFalse(ti.check_and_change_state_before_execution())
+ assert not ti.check_and_change_state_before_execution()
def test_try_number(self):
"""
@@ -1191,12 +1196,12 @@ def test_try_number(self):
dag = models.DAG(dag_id='test_check_and_change_state_before_execution')
task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE)
ti = TI(task=task, execution_date=timezone.utcnow())
- self.assertEqual(1, ti.try_number)
+ assert 1 == ti.try_number
ti.try_number = 2
ti.state = State.RUNNING
- self.assertEqual(2, ti.try_number)
+ assert 2 == ti.try_number
ti.state = State.SUCCESS
- self.assertEqual(3, ti.try_number)
+ assert 3 == ti.try_number
def test_get_num_running_task_instances(self):
session = settings.Session()
@@ -1217,9 +1222,9 @@ def test_get_num_running_task_instances(self):
session.add(ti3)
session.commit()
- self.assertEqual(1, ti1.get_num_running_task_instances(session=session))
- self.assertEqual(1, ti2.get_num_running_task_instances(session=session))
- self.assertEqual(1, ti3.get_num_running_task_instances(session=session))
+ assert 1 == ti1.get_num_running_task_instances(session=session)
+ assert 1 == ti2.get_num_running_task_instances(session=session)
+ assert 1 == ti3.get_num_running_task_instances(session=session)
# def test_log_url(self):
# now = pendulum.now('Europe/Brussels')
@@ -1244,7 +1249,7 @@ def test_log_url(self):
'&task_id=op'
'&dag_id=dag'
)
- self.assertEqual(ti.log_url, expected_url)
+ assert ti.log_url == expected_url
def test_mark_success_url(self):
now = pendulum.now('Europe/Brussels')
@@ -1254,9 +1259,9 @@ def test_mark_success_url(self):
query = urllib.parse.parse_qs(
urllib.parse.urlparse(ti.mark_success_url).query, keep_blank_values=True, strict_parsing=True
)
- self.assertEqual(query['dag_id'][0], 'dag')
- self.assertEqual(query['task_id'][0], 'op')
- self.assertEqual(pendulum.parse(query['execution_date'][0]), now)
+ assert query['dag_id'][0] == 'dag'
+ assert query['task_id'][0] == 'op'
+ assert pendulum.parse(query['execution_date'][0]) == now
def test_overwrite_params_with_dag_run_conf(self):
task = DummyOperator(task_id='op')
@@ -1267,7 +1272,7 @@ def test_overwrite_params_with_dag_run_conf(self):
ti.overwrite_params_with_dag_run_conf(params, dag_run)
- self.assertEqual(True, params["override"])
+ assert params["override"] is True
def test_overwrite_params_with_dag_run_none(self):
task = DummyOperator(task_id='op')
@@ -1276,7 +1281,7 @@ def test_overwrite_params_with_dag_run_none(self):
ti.overwrite_params_with_dag_run_conf(params, None)
- self.assertEqual(False, params["override"])
+ assert params["override"] is False
def test_overwrite_params_with_dag_run_conf_none(self):
task = DummyOperator(task_id='op')
@@ -1286,7 +1291,7 @@ def test_overwrite_params_with_dag_run_conf_none(self):
ti.overwrite_params_with_dag_run_conf(params, dag_run)
- self.assertEqual(False, params["override"])
+ assert params["override"] is False
@patch('airflow.models.taskinstance.send_email')
def test_email_alert(self, mock_send_email):
@@ -1303,10 +1308,10 @@ def test_email_alert(self, mock_send_email):
pass
(email, title, body), _ = mock_send_email.call_args
- self.assertEqual(email, 'to')
- self.assertIn('test_email_alert', title)
- self.assertIn('test_email_alert', body)
- self.assertIn('Try 1', body)
+ assert email == 'to'
+ assert 'test_email_alert' in title
+ assert 'test_email_alert' in body
+ assert 'Try 1' in body
@conf_vars(
{
@@ -1335,9 +1340,9 @@ def test_email_alert_with_config(self, mock_send_email):
pass
(email, title, body), _ = mock_send_email.call_args
- self.assertEqual(email, 'to')
- self.assertEqual('template: test_email_alert_with_config', title)
- self.assertEqual('template: test_email_alert_with_config', body)
+ assert email == 'to'
+ assert 'template: test_email_alert_with_config' == title
+ assert 'template: test_email_alert_with_config' == body
def test_set_duration(self):
task = DummyOperator(task_id='op', email='test@test.test')
@@ -1348,13 +1353,13 @@ def test_set_duration(self):
ti.start_date = datetime.datetime(2018, 10, 1, 1)
ti.end_date = datetime.datetime(2018, 10, 1, 2)
ti.set_duration()
- self.assertEqual(ti.duration, 3600)
+ assert ti.duration == 3600
def test_set_duration_empty_dates(self):
task = DummyOperator(task_id='op', email='test@test.test')
ti = TI(task=task, execution_date=datetime.datetime.now())
ti.set_duration()
- self.assertIsNone(ti.duration)
+ assert ti.duration is None
def test_success_callback_no_race_condition(self):
callback_wrapper = CallbackWrapper()
@@ -1384,10 +1389,10 @@ def test_success_callback_no_race_condition(self):
callback_wrapper.wrap_task_instance(ti)
ti._run_raw_task()
- self.assertTrue(callback_wrapper.callback_ran)
- self.assertEqual(callback_wrapper.task_state_in_callback, State.RUNNING)
+ assert callback_wrapper.callback_ran
+ assert callback_wrapper.task_state_in_callback == State.RUNNING
ti.refresh_from_db()
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
@staticmethod
def _test_previous_dates_setup(
@@ -1438,11 +1443,11 @@ def test_previous_ti(self, _, schedule_interval, catchup) -> None:
ti_list = self._test_previous_dates_setup(schedule_interval, catchup, scenario)
- self.assertIsNone(ti_list[0].get_previous_ti())
+ assert ti_list[0].get_previous_ti() is None
- self.assertEqual(ti_list[2].get_previous_ti().execution_date, ti_list[1].execution_date)
+ assert ti_list[2].get_previous_ti().execution_date == ti_list[1].execution_date
- self.assertNotEqual(ti_list[2].get_previous_ti().execution_date, ti_list[0].execution_date)
+ assert ti_list[2].get_previous_ti().execution_date != ti_list[0].execution_date
@parameterized.expand(_prev_dates_param_list)
def test_previous_ti_success(self, _, schedule_interval, catchup) -> None:
@@ -1451,16 +1456,12 @@ def test_previous_ti_success(self, _, schedule_interval, catchup) -> None:
ti_list = self._test_previous_dates_setup(schedule_interval, catchup, scenario)
- self.assertIsNone(ti_list[0].get_previous_ti(state=State.SUCCESS))
- self.assertIsNone(ti_list[1].get_previous_ti(state=State.SUCCESS))
+ assert ti_list[0].get_previous_ti(state=State.SUCCESS) is None
+ assert ti_list[1].get_previous_ti(state=State.SUCCESS) is None
- self.assertEqual(
- ti_list[3].get_previous_ti(state=State.SUCCESS).execution_date, ti_list[1].execution_date
- )
+ assert ti_list[3].get_previous_ti(state=State.SUCCESS).execution_date == ti_list[1].execution_date
- self.assertNotEqual(
- ti_list[3].get_previous_ti(state=State.SUCCESS).execution_date, ti_list[2].execution_date
- )
+ assert ti_list[3].get_previous_ti(state=State.SUCCESS).execution_date != ti_list[2].execution_date
@parameterized.expand(_prev_dates_param_list)
def test_previous_execution_date_success(self, _, schedule_interval, catchup) -> None:
@@ -1469,14 +1470,10 @@ def test_previous_execution_date_success(self, _, schedule_interval, catchup) ->
ti_list = self._test_previous_dates_setup(schedule_interval, catchup, scenario)
- self.assertIsNone(ti_list[0].get_previous_execution_date(state=State.SUCCESS))
- self.assertIsNone(ti_list[1].get_previous_execution_date(state=State.SUCCESS))
- self.assertEqual(
- ti_list[3].get_previous_execution_date(state=State.SUCCESS), ti_list[1].execution_date
- )
- self.assertNotEqual(
- ti_list[3].get_previous_execution_date(state=State.SUCCESS), ti_list[2].execution_date
- )
+ assert ti_list[0].get_previous_execution_date(state=State.SUCCESS) is None
+ assert ti_list[1].get_previous_execution_date(state=State.SUCCESS) is None
+ assert ti_list[3].get_previous_execution_date(state=State.SUCCESS) == ti_list[1].execution_date
+ assert ti_list[3].get_previous_execution_date(state=State.SUCCESS) != ti_list[2].execution_date
@parameterized.expand(_prev_dates_param_list)
def test_previous_start_date_success(self, _, schedule_interval, catchup) -> None:
@@ -1485,16 +1482,10 @@ def test_previous_start_date_success(self, _, schedule_interval, catchup) -> Non
ti_list = self._test_previous_dates_setup(schedule_interval, catchup, scenario)
- self.assertIsNone(ti_list[0].get_previous_start_date(state=State.SUCCESS))
- self.assertIsNone(ti_list[1].get_previous_start_date(state=State.SUCCESS))
- self.assertEqual(
- ti_list[3].get_previous_start_date(state=State.SUCCESS),
- ti_list[1].start_date,
- )
- self.assertNotEqual(
- ti_list[3].get_previous_start_date(state=State.SUCCESS),
- ti_list[2].start_date,
- )
+ assert ti_list[0].get_previous_start_date(state=State.SUCCESS) is None
+ assert ti_list[1].get_previous_start_date(state=State.SUCCESS) is None
+ assert ti_list[3].get_previous_start_date(state=State.SUCCESS) == ti_list[1].start_date
+ assert ti_list[3].get_previous_start_date(state=State.SUCCESS) != ti_list[2].start_date
def test_pendulum_template_dates(self):
dag = models.DAG(
@@ -1508,9 +1499,9 @@ def test_pendulum_template_dates(self):
template_context = ti.get_template_context()
- self.assertIsInstance(template_context["execution_date"], pendulum.DateTime)
- self.assertIsInstance(template_context["next_execution_date"], pendulum.DateTime)
- self.assertIsInstance(template_context["prev_execution_date"], pendulum.DateTime)
+ assert isinstance(template_context["execution_date"], pendulum.DateTime)
+ assert isinstance(template_context["next_execution_date"], pendulum.DateTime)
+ assert isinstance(template_context["prev_execution_date"], pendulum.DateTime)
@parameterized.expand(
[
@@ -1532,7 +1523,7 @@ def test_template_with_variable(self, content, expected_output):
ti = TI(task=task, execution_date=DEFAULT_DATE)
context = ti.get_template_context()
result = task.render_template(content, context)
- self.assertEqual(result, expected_output)
+ assert result == expected_output
def test_template_with_variable_missing(self):
"""
@@ -1543,7 +1534,7 @@ def test_template_with_variable_missing(self):
ti = TI(task=task, execution_date=DEFAULT_DATE)
context = ti.get_template_context()
- with self.assertRaises(KeyError):
+ with pytest.raises(KeyError):
task.render_template('{{ var.value.get("missing_variable") }}', context)
@parameterized.expand(
@@ -1567,7 +1558,7 @@ def test_template_with_json_variable(self, content, expected_output):
ti = TI(task=task, execution_date=DEFAULT_DATE)
context = ti.get_template_context()
result = task.render_template(content, context)
- self.assertEqual(result, expected_output)
+ assert result == expected_output
def test_template_with_json_variable_missing(self):
with DAG('test-dag', start_date=DEFAULT_DATE):
@@ -1575,7 +1566,7 @@ def test_template_with_json_variable_missing(self):
ti = TI(task=task, execution_date=DEFAULT_DATE)
context = ti.get_template_context()
- with self.assertRaises(KeyError):
+ with pytest.raises(KeyError):
task.render_template('{{ var.json.get("missing_variable") }}', context)
def test_execute_callback(self):
@@ -1584,7 +1575,7 @@ def test_execute_callback(self):
def on_execute_callable(context):
nonlocal called
called = True
- self.assertEqual(context['dag_run'].dag_id, 'test_dagrun_execute_callback')
+ assert context['dag_run'].dag_id == 'test_dagrun_execute_callback'
dag = DAG(
'test_execute_callback',
@@ -1686,7 +1677,7 @@ def fail():
ti.run()
except AirflowFailException:
pass # expected
- self.assertEqual(State.FAILED, ti.state)
+ assert State.FAILED == ti.state
def test_retries_on_other_exceptions(self):
def fail():
@@ -1706,15 +1697,13 @@ def fail():
ti.run()
except AirflowException:
pass # expected
- self.assertEqual(State.UP_FOR_RETRY, ti.state)
+ assert State.UP_FOR_RETRY == ti.state
def _env_var_check_callback(self):
- self.assertEqual('test_echo_env_variables', os.environ['AIRFLOW_CTX_DAG_ID'])
- self.assertEqual('hive_in_python_op', os.environ['AIRFLOW_CTX_TASK_ID'])
- self.assertEqual(DEFAULT_DATE.isoformat(), os.environ['AIRFLOW_CTX_EXECUTION_DATE'])
- self.assertEqual(
- DagRun.generate_run_id(DagRunType.MANUAL, DEFAULT_DATE), os.environ['AIRFLOW_CTX_DAG_RUN_ID']
- )
+ assert 'test_echo_env_variables' == os.environ['AIRFLOW_CTX_DAG_ID']
+ assert 'hive_in_python_op' == os.environ['AIRFLOW_CTX_TASK_ID']
+ assert DEFAULT_DATE.isoformat() == os.environ['AIRFLOW_CTX_EXECUTION_DATE']
+ assert DagRun.generate_run_id(DagRunType.MANUAL, DEFAULT_DATE) == os.environ['AIRFLOW_CTX_DAG_RUN_ID']
def test_echo_env_variables(self):
dag = DAG(
@@ -1739,7 +1728,7 @@ def test_echo_env_variables(self):
session.commit()
ti._run_raw_task()
ti.refresh_from_db()
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
@patch.object(Stats, 'incr')
def test_task_stats(self, stats_mock):
@@ -1764,8 +1753,8 @@ def test_task_stats(self, stats_mock):
ti._run_raw_task()
ti.refresh_from_db()
stats_mock.assert_called_with(f'ti.finish.{dag.dag_id}.{op.task_id}.{ti.state}')
- self.assertIn(call(f'ti.start.{dag.dag_id}.{op.task_id}'), stats_mock.mock_calls)
- self.assertEqual(stats_mock.call_count, 5)
+ assert call(f'ti.start.{dag.dag_id}.{op.task_id}') in stats_mock.mock_calls
+ assert stats_mock.call_count == 5
def test_generate_command_default_param(self):
dag_id = 'test_generate_command_default_param'
@@ -1808,7 +1797,7 @@ def test_get_rendered_template_fields(self):
new_ti = TI(task=new_task, execution_date=DEFAULT_DATE)
new_ti.get_rendered_template_fields()
- self.assertEqual("op1", ti.task.bash_command)
+ assert "op1" == ti.task.bash_command
# CleanUp
with create_session() as session:
@@ -1863,7 +1852,7 @@ def test_get_rendered_k8s_spec(self):
with create_session() as session:
rtif = RenderedTaskInstanceFields(ti)
session.add(rtif)
- self.assertEqual(rtif.k8s_pod_yaml, expected_pod_spec)
+ assert rtif.k8s_pod_yaml == expected_pod_spec
# Create new TI for the same Task
with DAG('test_get_rendered_k8s_spec', start_date=DEFAULT_DATE):
@@ -1872,7 +1861,7 @@ def test_get_rendered_k8s_spec(self):
new_ti = TI(task=new_task, execution_date=DEFAULT_DATE)
pod_spec = new_ti.get_rendered_k8s_spec()
- self.assertEqual(expected_pod_spec, pod_spec)
+ assert expected_pod_spec == pod_spec
# CleanUp
with create_session() as session:
@@ -1881,7 +1870,7 @@ def test_get_rendered_k8s_spec(self):
def validate_ti_states(self, dag_run, ti_state_mapping, error_message):
for task_id, expected_state in ti_state_mapping.items():
task_instance = dag_run.get_task_instance(task_id=task_id)
- self.assertEqual(task_instance.state, expected_state, error_message)
+ assert task_instance.state == expected_state, error_message
@parameterized.expand(
[
@@ -2089,15 +2078,15 @@ def test_operator_field_with_serialization(self):
dag = DAG('test_queries', start_date=DEFAULT_DATE)
task = DummyOperator(task_id='op', dag=dag)
- self.assertEqual(task.task_type, 'DummyOperator')
+ assert task.task_type == 'DummyOperator'
# Verify that ti.operator field renders correctly "without" Serialization
ti = TI(task=task, execution_date=datetime.datetime.now())
- self.assertEqual(ti.operator, "DummyOperator")
+ assert ti.operator == "DummyOperator"
serialized_op = SerializedBaseOperator.serialize_operator(task)
deserialized_op = SerializedBaseOperator.deserialize_operator(serialized_op)
- self.assertEqual(deserialized_op.task_type, 'DummyOperator')
+ assert deserialized_op.task_type == 'DummyOperator'
# Verify that ti.operator field renders correctly "with" Serialization
ser_ti = TI(task=deserialized_op, execution_date=datetime.datetime.now())
- self.assertEqual(ser_ti.operator, "DummyOperator")
+ assert ser_ti.operator == "DummyOperator"
diff --git a/tests/models/test_variable.py b/tests/models/test_variable.py
index 91fb36914c40d..e5e1d7aeda605 100644
--- a/tests/models/test_variable.py
+++ b/tests/models/test_variable.py
@@ -19,6 +19,7 @@
import unittest
from unittest import mock
+import pytest
from cryptography.fernet import Fernet
from parameterized import parameterized
@@ -45,8 +46,8 @@ def test_variable_no_encryption(self):
Variable.set('key', 'value')
session = settings.Session()
test_var = session.query(Variable).filter(Variable.key == 'key').one()
- self.assertFalse(test_var.is_encrypted)
- self.assertEqual(test_var.val, 'value')
+ assert not test_var.is_encrypted
+ assert test_var.val == 'value'
@conf_vars({('core', 'fernet_key'): Fernet.generate_key().decode()})
def test_variable_with_encryption(self):
@@ -56,8 +57,8 @@ def test_variable_with_encryption(self):
Variable.set('key', 'value')
session = settings.Session()
test_var = session.query(Variable).filter(Variable.key == 'key').one()
- self.assertTrue(test_var.is_encrypted)
- self.assertEqual(test_var.val, 'value')
+ assert test_var.is_encrypted
+ assert test_var.val == 'value'
@parameterized.expand(['value', ''])
def test_var_with_encryption_rotate_fernet_key(self, test_value):
@@ -71,83 +72,79 @@ def test_var_with_encryption_rotate_fernet_key(self, test_value):
Variable.set('key', test_value)
session = settings.Session()
test_var = session.query(Variable).filter(Variable.key == 'key').one()
- self.assertTrue(test_var.is_encrypted)
- self.assertEqual(test_var.val, test_value)
- self.assertEqual(Fernet(key1).decrypt(test_var._val.encode()), test_value.encode())
+ assert test_var.is_encrypted
+ assert test_var.val == test_value
+ assert Fernet(key1).decrypt(test_var._val.encode()) == test_value.encode()
# Test decrypt of old value with new key
with conf_vars({('core', 'fernet_key'): ','.join([key2.decode(), key1.decode()])}):
crypto._fernet = None
- self.assertEqual(test_var.val, test_value)
+ assert test_var.val == test_value
# Test decrypt of new value with new key
test_var.rotate_fernet_key()
- self.assertTrue(test_var.is_encrypted)
- self.assertEqual(test_var.val, test_value)
- self.assertEqual(Fernet(key2).decrypt(test_var._val.encode()), test_value.encode())
+ assert test_var.is_encrypted
+ assert test_var.val == test_value
+ assert Fernet(key2).decrypt(test_var._val.encode()) == test_value.encode()
def test_variable_set_get_round_trip(self):
Variable.set("tested_var_set_id", "Monday morning breakfast")
- self.assertEqual("Monday morning breakfast", Variable.get("tested_var_set_id"))
+ assert "Monday morning breakfast" == Variable.get("tested_var_set_id")
def test_variable_set_with_env_variable(self):
Variable.set("key", "db-value")
with self.assertLogs(variable.log) as log_context:
with mock.patch.dict('os.environ', AIRFLOW_VAR_KEY="env-value"):
Variable.set("key", "new-db-value")
- self.assertEqual("env-value", Variable.get("key"))
- self.assertEqual("new-db-value", Variable.get("key"))
-
- self.assertEqual(
- log_context.records[0].message,
- (
- 'You have the environment variable AIRFLOW_VAR_KEY defined, which takes precedence over '
- 'reading from the database. The value will be saved, but to read it you have to delete '
- 'the environment variable.'
- ),
+ assert "env-value" == Variable.get("key")
+ assert "new-db-value" == Variable.get("key")
+
+ assert log_context.records[0].message == (
+ 'You have the environment variable AIRFLOW_VAR_KEY defined, which takes precedence over '
+ 'reading from the database. The value will be saved, but to read it you have to delete '
+ 'the environment variable.'
)
def test_variable_set_get_round_trip_json(self):
value = {"a": 17, "b": 47}
Variable.set("tested_var_set_id", value, serialize_json=True)
- self.assertEqual(value, Variable.get("tested_var_set_id", deserialize_json=True))
+ assert value == Variable.get("tested_var_set_id", deserialize_json=True)
def test_variable_set_existing_value_to_blank(self):
test_value = 'Some value'
test_key = 'test_key'
Variable.set(test_key, test_value)
Variable.set(test_key, '')
- self.assertEqual('', Variable.get('test_key'))
+ assert '' == Variable.get('test_key')
def test_get_non_existing_var_should_return_default(self):
default_value = "some default val"
- self.assertEqual(default_value, Variable.get("thisIdDoesNotExist", default_var=default_value))
+ assert default_value == Variable.get("thisIdDoesNotExist", default_var=default_value)
def test_get_non_existing_var_should_raise_key_error(self):
- with self.assertRaises(KeyError):
+ with pytest.raises(KeyError):
Variable.get("thisIdDoesNotExist")
def test_get_non_existing_var_with_none_default_should_return_none(self):
- self.assertIsNone(Variable.get("thisIdDoesNotExist", default_var=None))
+ assert Variable.get("thisIdDoesNotExist", default_var=None) is None
def test_get_non_existing_var_should_not_deserialize_json_default(self):
default_value = "}{ this is a non JSON default }{"
- self.assertEqual(
- default_value,
- Variable.get("thisIdDoesNotExist", default_var=default_value, deserialize_json=True),
+ assert default_value == Variable.get(
+ "thisIdDoesNotExist", default_var=default_value, deserialize_json=True
)
def test_variable_setdefault_round_trip(self):
key = "tested_var_setdefault_1_id"
value = "Monday morning breakfast in Paris"
Variable.setdefault(key, value)
- self.assertEqual(value, Variable.get(key))
+ assert value == Variable.get(key)
def test_variable_setdefault_round_trip_json(self):
key = "tested_var_setdefault_2_id"
value = {"city": 'Paris', "Happiness": True}
Variable.setdefault(key, value, deserialize_json=True)
- self.assertEqual(value, Variable.get(key, deserialize_json=True))
+ assert value == Variable.get(key, deserialize_json=True)
def test_variable_setdefault_existing_json(self):
key = "tested_var_setdefault_2_id"
@@ -155,8 +152,8 @@ def test_variable_setdefault_existing_json(self):
Variable.set(key, value, serialize_json=True)
val = Variable.setdefault(key, value, deserialize_json=True)
# Check the returned value, and the stored value are handled correctly.
- self.assertEqual(value, val)
- self.assertEqual(value, Variable.get(key, deserialize_json=True))
+ assert value == val
+ assert value == Variable.get(key, deserialize_json=True)
def test_variable_delete(self):
key = "tested_var_delete"
@@ -164,14 +161,14 @@ def test_variable_delete(self):
# No-op if the variable doesn't exist
Variable.delete(key)
- with self.assertRaises(KeyError):
+ with pytest.raises(KeyError):
Variable.get(key)
# Set the variable
Variable.set(key, value)
- self.assertEqual(value, Variable.get(key))
+ assert value == Variable.get(key)
# Delete the variable
Variable.delete(key)
- with self.assertRaises(KeyError):
+ with pytest.raises(KeyError):
Variable.get(key)
diff --git a/tests/models/test_xcom.py b/tests/models/test_xcom.py
index c5f523a46e3f5..0339edbb90fec 100644
--- a/tests/models/test_xcom.py
+++ b/tests/models/test_xcom.py
@@ -18,6 +18,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow import settings
from airflow.configuration import conf
from airflow.models.xcom import BaseXCom, XCom, resolve_xcom_backend
@@ -74,7 +76,7 @@ def test_xcom_disable_pickle_type(self):
.value
)
- self.assertEqual(ret_value, json_obj)
+ assert ret_value == json_obj
session = settings.Session()
ret_value = (
@@ -89,7 +91,7 @@ def test_xcom_disable_pickle_type(self):
.value
)
- self.assertEqual(ret_value, json_obj)
+ assert ret_value == json_obj
@conf_vars({("core", "enable_xcom_pickling"): "False"})
def test_xcom_get_one_disable_pickle_type(self):
@@ -102,7 +104,7 @@ def test_xcom_get_one_disable_pickle_type(self):
ret_value = XCom.get_one(key=key, dag_id=dag_id, task_id=task_id, execution_date=execution_date)
- self.assertEqual(ret_value, json_obj)
+ assert ret_value == json_obj
session = settings.Session()
ret_value = (
@@ -117,7 +119,7 @@ def test_xcom_get_one_disable_pickle_type(self):
.value
)
- self.assertEqual(ret_value, json_obj)
+ assert ret_value == json_obj
@conf_vars({("core", "enable_xcom_pickling"): "True"})
def test_xcom_enable_pickle_type(self):
@@ -134,7 +136,7 @@ def test_xcom_enable_pickle_type(self):
.value
)
- self.assertEqual(ret_value, json_obj)
+ assert ret_value == json_obj
session = settings.Session()
ret_value = (
@@ -149,7 +151,7 @@ def test_xcom_enable_pickle_type(self):
.value
)
- self.assertEqual(ret_value, json_obj)
+ assert ret_value == json_obj
@conf_vars({("core", "enable_xcom_pickling"): "True"})
def test_xcom_get_one_enable_pickle_type(self):
@@ -162,7 +164,7 @@ def test_xcom_get_one_enable_pickle_type(self):
ret_value = XCom.get_one(key=key, dag_id=dag_id, task_id=task_id, execution_date=execution_date)
- self.assertEqual(ret_value, json_obj)
+ assert ret_value == json_obj
session = settings.Session()
ret_value = (
@@ -177,7 +179,7 @@ def test_xcom_get_one_enable_pickle_type(self):
.value
)
- self.assertEqual(ret_value, json_obj)
+ assert ret_value == json_obj
def test_xcom_deserialize_with_json_to_pickle_switch(self):
json_obj = {"key": "value"}
@@ -192,7 +194,7 @@ def test_xcom_deserialize_with_json_to_pickle_switch(self):
with conf_vars({("core", "enable_xcom_pickling"): "True"}):
ret_value = XCom.get_one(key=key, dag_id=dag_id, task_id=task_id, execution_date=execution_date)
- self.assertEqual(ret_value, json_obj)
+ assert ret_value == json_obj
def test_xcom_deserialize_with_pickle_to_json_switch(self):
json_obj = {"key": "value"}
@@ -207,7 +209,7 @@ def test_xcom_deserialize_with_pickle_to_json_switch(self):
with conf_vars({("core", "enable_xcom_pickling"): "False"}):
ret_value = XCom.get_one(key=key, dag_id=dag_id, task_id=task_id, execution_date=execution_date)
- self.assertEqual(ret_value, json_obj)
+ assert ret_value == json_obj
@conf_vars({("core", "xcom_enable_pickling"): "False"})
def test_xcom_disable_pickle_type_fail_on_non_json(self):
@@ -215,15 +217,14 @@ class PickleRce:
def __reduce__(self):
return os.system, ("ls -alt",)
- self.assertRaises(
- TypeError,
- XCom.set,
- key="xcom_test3",
- value=PickleRce(),
- dag_id="test_dag3",
- task_id="test_task3",
- execution_date=timezone.utcnow(),
- )
+ with pytest.raises(TypeError):
+ XCom.set(
+ key="xcom_test3",
+ value=PickleRce(),
+ dag_id="test_dag3",
+ task_id="test_task3",
+ execution_date=timezone.utcnow(),
+ )
@conf_vars({("core", "xcom_enable_pickling"): "True"})
def test_xcom_get_many(self):
@@ -242,7 +243,7 @@ def test_xcom_get_many(self):
results = XCom.get_many(key=key, execution_date=execution_date)
for result in results:
- self.assertEqual(result.value, json_obj)
+ assert result.value == json_obj
@mock.patch("airflow.models.xcom.XCom.orm_deserialize_value")
def test_xcom_init_on_load_uses_orm_deserialize_value(self, mock_orm_deserialize):
diff --git a/tests/operators/test_bash.py b/tests/operators/test_bash.py
index 9c0cca18d817e..871d47506ca97 100644
--- a/tests/operators/test_bash.py
+++ b/tests/operators/test_bash.py
@@ -22,6 +22,8 @@
from tempfile import NamedTemporaryFile
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.models import DagRun
from airflow.models.dag import DAG
@@ -78,24 +80,24 @@ def test_echo_env_variables(self):
with open(tmp_file.name) as file:
output = ''.join(file.readlines())
- self.assertIn('MY_PATH_TO_AIRFLOW_HOME', output)
+ assert 'MY_PATH_TO_AIRFLOW_HOME' in output
# exported in run-tests as part of PYTHONPATH
- self.assertIn('AWESOME_PYTHONPATH', output)
- self.assertIn('bash_op_test', output)
- self.assertIn('echo_env_vars', output)
- self.assertIn(DEFAULT_DATE.isoformat(), output)
- self.assertIn(DagRun.generate_run_id(DagRunType.MANUAL, DEFAULT_DATE), output)
+ assert 'AWESOME_PYTHONPATH' in output
+ assert 'bash_op_test' in output
+ assert 'echo_env_vars' in output
+ assert DEFAULT_DATE.isoformat() in output
+ assert DagRun.generate_run_id(DagRunType.MANUAL, DEFAULT_DATE) in output
def test_return_value(self):
bash_operator = BashOperator(bash_command='echo "stdout"', task_id='test_return_value', dag=None)
return_value = bash_operator.execute(context={})
- self.assertEqual(return_value, 'stdout')
+ assert return_value == 'stdout'
def test_raise_exception_on_non_zero_exit_code(self):
bash_operator = BashOperator(bash_command='exit 42', task_id='test_return_value', dag=None)
- with self.assertRaisesRegex(
- AirflowException, "Bash command failed\\. The command returned a non-zero exit code\\."
+ with pytest.raises(
+ AirflowException, match="Bash command failed\\. The command returned a non-zero exit code\\."
):
bash_operator.execute(context={})
@@ -104,12 +106,12 @@ def test_task_retries(self):
bash_command='echo "stdout"', task_id='test_task_retries', retries=2, dag=None
)
- self.assertEqual(bash_operator.retries, 2)
+ assert bash_operator.retries == 2
def test_default_retries(self):
bash_operator = BashOperator(bash_command='echo "stdout"', task_id='test_default_retries', dag=None)
- self.assertEqual(bash_operator.retries, 0)
+ assert bash_operator.retries == 0
@mock.patch.dict('os.environ', clear=True)
@mock.patch(
diff --git a/tests/operators/test_branch_operator.py b/tests/operators/test_branch_operator.py
index 8ae89b0ed83ed..d3725340234de 100644
--- a/tests/operators/test_branch_operator.py
+++ b/tests/operators/test_branch_operator.py
@@ -83,12 +83,12 @@ def test_without_dag_run(self):
for ti in tis:
if ti.task_id == 'make_choice':
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
elif ti.task_id == 'branch_1':
# should exist with state None
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
elif ti.task_id == 'branch_2':
- self.assertEqual(ti.state, State.SKIPPED)
+ assert ti.state == State.SKIPPED
else:
raise Exception
@@ -115,7 +115,7 @@ def test_branch_list_without_dag_run(self):
for ti in tis:
if ti.task_id in expected:
- self.assertEqual(ti.state, expected[ti.task_id])
+ assert ti.state == expected[ti.task_id]
else:
raise Exception
@@ -137,11 +137,11 @@ def test_with_dag_run(self):
tis = dagrun.get_task_instances()
for ti in tis:
if ti.task_id == 'make_choice':
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
elif ti.task_id == 'branch_1':
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
elif ti.task_id == 'branch_2':
- self.assertEqual(ti.state, State.SKIPPED)
+ assert ti.state == State.SKIPPED
else:
raise Exception
@@ -163,10 +163,10 @@ def test_with_skip_in_branch_downstream_dependencies(self):
tis = dagrun.get_task_instances()
for ti in tis:
if ti.task_id == 'make_choice':
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
elif ti.task_id == 'branch_1':
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
elif ti.task_id == 'branch_2':
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
else:
raise Exception
diff --git a/tests/operators/test_latest_only_operator.py b/tests/operators/test_latest_only_operator.py
index 30ac14a347781..ee0a0a4f0044c 100644
--- a/tests/operators/test_latest_only_operator.py
+++ b/tests/operators/test_latest_only_operator.py
@@ -107,47 +107,35 @@ def test_skipping_non_latest(self):
latest_instances = get_task_instances('latest')
exec_date_to_latest_state = {ti.execution_date: ti.state for ti in latest_instances}
- self.assertEqual(
- {
- timezone.datetime(2016, 1, 1): 'success',
- timezone.datetime(2016, 1, 1, 12): 'success',
- timezone.datetime(2016, 1, 2): 'success',
- },
- exec_date_to_latest_state,
- )
+ assert {
+ timezone.datetime(2016, 1, 1): 'success',
+ timezone.datetime(2016, 1, 1, 12): 'success',
+ timezone.datetime(2016, 1, 2): 'success',
+ } == exec_date_to_latest_state
downstream_instances = get_task_instances('downstream')
exec_date_to_downstream_state = {ti.execution_date: ti.state for ti in downstream_instances}
- self.assertEqual(
- {
- timezone.datetime(2016, 1, 1): 'skipped',
- timezone.datetime(2016, 1, 1, 12): 'skipped',
- timezone.datetime(2016, 1, 2): 'success',
- },
- exec_date_to_downstream_state,
- )
+ assert {
+ timezone.datetime(2016, 1, 1): 'skipped',
+ timezone.datetime(2016, 1, 1, 12): 'skipped',
+ timezone.datetime(2016, 1, 2): 'success',
+ } == exec_date_to_downstream_state
downstream_instances = get_task_instances('downstream_2')
exec_date_to_downstream_state = {ti.execution_date: ti.state for ti in downstream_instances}
- self.assertEqual(
- {
- timezone.datetime(2016, 1, 1): None,
- timezone.datetime(2016, 1, 1, 12): None,
- timezone.datetime(2016, 1, 2): 'success',
- },
- exec_date_to_downstream_state,
- )
+ assert {
+ timezone.datetime(2016, 1, 1): None,
+ timezone.datetime(2016, 1, 1, 12): None,
+ timezone.datetime(2016, 1, 2): 'success',
+ } == exec_date_to_downstream_state
downstream_instances = get_task_instances('downstream_3')
exec_date_to_downstream_state = {ti.execution_date: ti.state for ti in downstream_instances}
- self.assertEqual(
- {
- timezone.datetime(2016, 1, 1): 'success',
- timezone.datetime(2016, 1, 1, 12): 'success',
- timezone.datetime(2016, 1, 2): 'success',
- },
- exec_date_to_downstream_state,
- )
+ assert {
+ timezone.datetime(2016, 1, 1): 'success',
+ timezone.datetime(2016, 1, 1, 12): 'success',
+ timezone.datetime(2016, 1, 2): 'success',
+ } == exec_date_to_downstream_state
def test_not_skipping_external(self):
latest_task = LatestOnlyOperator(task_id='latest', dag=self.dag)
@@ -187,33 +175,24 @@ def test_not_skipping_external(self):
latest_instances = get_task_instances('latest')
exec_date_to_latest_state = {ti.execution_date: ti.state for ti in latest_instances}
- self.assertEqual(
- {
- timezone.datetime(2016, 1, 1): 'success',
- timezone.datetime(2016, 1, 1, 12): 'success',
- timezone.datetime(2016, 1, 2): 'success',
- },
- exec_date_to_latest_state,
- )
+ assert {
+ timezone.datetime(2016, 1, 1): 'success',
+ timezone.datetime(2016, 1, 1, 12): 'success',
+ timezone.datetime(2016, 1, 2): 'success',
+ } == exec_date_to_latest_state
downstream_instances = get_task_instances('downstream')
exec_date_to_downstream_state = {ti.execution_date: ti.state for ti in downstream_instances}
- self.assertEqual(
- {
- timezone.datetime(2016, 1, 1): 'success',
- timezone.datetime(2016, 1, 1, 12): 'success',
- timezone.datetime(2016, 1, 2): 'success',
- },
- exec_date_to_downstream_state,
- )
+ assert {
+ timezone.datetime(2016, 1, 1): 'success',
+ timezone.datetime(2016, 1, 1, 12): 'success',
+ timezone.datetime(2016, 1, 2): 'success',
+ } == exec_date_to_downstream_state
downstream_instances = get_task_instances('downstream_2')
exec_date_to_downstream_state = {ti.execution_date: ti.state for ti in downstream_instances}
- self.assertEqual(
- {
- timezone.datetime(2016, 1, 1): 'success',
- timezone.datetime(2016, 1, 1, 12): 'success',
- timezone.datetime(2016, 1, 2): 'success',
- },
- exec_date_to_downstream_state,
- )
+ assert {
+ timezone.datetime(2016, 1, 1): 'success',
+ timezone.datetime(2016, 1, 1, 12): 'success',
+ timezone.datetime(2016, 1, 2): 'success',
+ } == exec_date_to_downstream_state
diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py
index bd704bfb666ac..14ebdf00b6b21 100644
--- a/tests/operators/test_python.py
+++ b/tests/operators/test_python.py
@@ -110,14 +110,14 @@ def clear_run(self):
self.run = False
def _assert_calls_equal(self, first, second):
- self.assertIsInstance(first, Call)
- self.assertIsInstance(second, Call)
- self.assertTupleEqual(first.args, second.args)
+ assert isinstance(first, Call)
+ assert isinstance(second, Call)
+ assert first.args == second.args
# eliminate context (conf, dag_run, task_instance, etc.)
test_args = ["an_int", "a_date", "a_templated_string"]
first.kwargs = {key: value for (key, value) in first.kwargs.items() if key in test_args}
second.kwargs = {key: value for (key, value) in second.kwargs.items() if key in test_args}
- self.assertDictEqual(first.kwargs, second.kwargs)
+ assert first.kwargs == second.kwargs
class TestPythonOperator(TestPythonBase):
@@ -130,18 +130,18 @@ def is_run(self):
def test_python_operator_run(self):
"""Tests that the python callable is invoked on task run."""
task = PythonOperator(python_callable=self.do_run, task_id='python_operator', dag=self.dag)
- self.assertFalse(self.is_run())
+ assert not self.is_run()
task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- self.assertTrue(self.is_run())
+ assert self.is_run()
def test_python_operator_python_callable_is_callable(self):
"""Tests that PythonOperator will only instantiate if
the python_callable argument is callable."""
not_callable = {}
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
PythonOperator(python_callable=not_callable, task_id='python_operator', dag=self.dag)
not_callable = None
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
PythonOperator(python_callable=not_callable, task_id='python_operator', dag=self.dag)
def test_python_callable_arguments_are_templatized(self):
@@ -171,7 +171,7 @@ def test_python_callable_arguments_are_templatized(self):
task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
ds_templated = DEFAULT_DATE.date().isoformat()
- self.assertEqual(1, len(recorded_calls))
+ assert 1 == len(recorded_calls)
self._assert_calls_equal(
recorded_calls[0],
Call(
@@ -207,7 +207,7 @@ def test_python_callable_keyword_arguments_are_templatized(self):
)
task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- self.assertEqual(1, len(recorded_calls))
+ assert 1 == len(recorded_calls)
self._assert_calls_equal(
recorded_calls[0],
Call(
@@ -229,11 +229,9 @@ def test_python_operator_shallow_copy_attr(self):
)
new_task = copy.deepcopy(original_task)
# shallow copy op_kwargs
- self.assertEqual(
- id(original_task.op_kwargs['certain_attrs']), id(new_task.op_kwargs['certain_attrs'])
- )
+ assert id(original_task.op_kwargs['certain_attrs']) == id(new_task.op_kwargs['certain_attrs'])
# shallow copy python_callable
- self.assertEqual(id(original_task.python_callable), id(new_task.python_callable))
+ assert id(original_task.python_callable) == id(new_task.python_callable)
def test_conflicting_kwargs(self):
self.dag.create_dagrun(
@@ -254,9 +252,9 @@ def func(dag):
task_id='python_operator', op_args=[1], python_callable=func, dag=self.dag
)
- with self.assertRaises(ValueError) as context:
+ with pytest.raises(ValueError) as ctx:
python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- self.assertTrue('dag' in context.exception, "'dag' not found in the exception")
+ assert 'dag' in str(ctx.value), "'dag' not found in the exception"
def test_provide_context_does_not_fail(self):
"""
@@ -271,8 +269,8 @@ def test_provide_context_does_not_fail(self):
)
def func(custom, dag):
- self.assertEqual(1, custom, "custom should be 1")
- self.assertIsNotNone(dag, "dag should be set")
+ assert 1 == custom, "custom should be 1"
+ assert dag is not None, "dag should be set"
python_operator = PythonOperator(
task_id='python_operator',
@@ -293,8 +291,8 @@ def test_context_with_conflicting_op_args(self):
)
def func(custom, dag):
- self.assertEqual(1, custom, "custom should be 1")
- self.assertIsNotNone(dag, "dag should be set")
+ assert 1 == custom, "custom should be 1"
+ assert dag is not None, "dag should be set"
python_operator = PythonOperator(
task_id='python_operator', op_kwargs={'custom': 1}, python_callable=func, dag=self.dag
@@ -312,7 +310,7 @@ def test_context_with_kwargs(self):
def func(**context):
# check if context is being set
- self.assertGreater(len(context), 0, "Context has not been injected")
+ assert len(context) > 0, "Context has not been injected"
python_operator = PythonOperator(
task_id='python_operator', op_kwargs={'custom': 1}, python_callable=func, dag=self.dag
@@ -758,12 +756,12 @@ def test_without_dag_run(self):
for ti in tis:
if ti.task_id == 'make_choice':
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
elif ti.task_id == 'branch_1':
# should exist with state None
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
elif ti.task_id == 'branch_2':
- self.assertEqual(ti.state, State.SKIPPED)
+ assert ti.state == State.SKIPPED
else:
raise ValueError(f'Invalid task id {ti.task_id} found!')
@@ -792,7 +790,7 @@ def test_branch_list_without_dag_run(self):
for ti in tis:
if ti.task_id in expected:
- self.assertEqual(ti.state, expected[ti.task_id])
+ assert ti.state == expected[ti.task_id]
else:
raise ValueError(f'Invalid task id {ti.task_id} found!')
@@ -817,11 +815,11 @@ def test_with_dag_run(self):
tis = dr.get_task_instances()
for ti in tis:
if ti.task_id == 'make_choice':
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
elif ti.task_id == 'branch_1':
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
elif ti.task_id == 'branch_2':
- self.assertEqual(ti.state, State.SKIPPED)
+ assert ti.state == State.SKIPPED
else:
raise ValueError(f'Invalid task id {ti.task_id} found!')
@@ -846,11 +844,11 @@ def test_with_skip_in_branch_downstream_dependencies(self):
tis = dr.get_task_instances()
for ti in tis:
if ti.task_id == 'make_choice':
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
elif ti.task_id == 'branch_1':
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
elif ti.task_id == 'branch_2':
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
else:
raise ValueError(f'Invalid task id {ti.task_id} found!')
@@ -875,11 +873,11 @@ def test_with_skip_in_branch_downstream_dependencies2(self):
tis = dr.get_task_instances()
for ti in tis:
if ti.task_id == 'make_choice':
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
elif ti.task_id == 'branch_1':
- self.assertEqual(ti.state, State.SKIPPED)
+ assert ti.state == State.SKIPPED
elif ti.task_id == 'branch_2':
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
else:
raise ValueError(f'Invalid task id {ti.task_id} found!')
@@ -904,7 +902,7 @@ def test_xcom_push(self):
tis = dr.get_task_instances()
for ti in tis:
if ti.task_id == 'make_choice':
- self.assertEqual(ti.xcom_pull(task_ids='make_choice'), 'branch_1')
+ assert ti.xcom_pull(task_ids='make_choice') == 'branch_1'
def test_clear_skipped_downstream_task(self):
"""
@@ -933,11 +931,11 @@ def test_clear_skipped_downstream_task(self):
tis = dr.get_task_instances()
for ti in tis:
if ti.task_id == 'make_choice':
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
elif ti.task_id == 'branch_1':
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
elif ti.task_id == 'branch_2':
- self.assertEqual(ti.state, State.SKIPPED)
+ assert ti.state == State.SKIPPED
else:
raise ValueError(f'Invalid task id {ti.task_id} found!')
@@ -954,11 +952,11 @@ def test_clear_skipped_downstream_task(self):
# Check if the states are correct after children tasks are cleared.
for ti in dr.get_task_instances():
if ti.task_id == 'make_choice':
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
elif ti.task_id == 'branch_1':
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
elif ti.task_id == 'branch_2':
- self.assertEqual(ti.state, State.SKIPPED)
+ assert ti.state == State.SKIPPED
else:
raise ValueError(f'Invalid task id {ti.task_id} found!')
@@ -1003,12 +1001,12 @@ def test_without_dag_run(self):
for ti in tis:
if ti.task_id == 'make_choice':
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
elif ti.task_id == 'upstream':
# should not exist
raise ValueError(f'Invalid task id {ti.task_id} found!')
elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
- self.assertEqual(ti.state, State.SKIPPED)
+ assert ti.state == State.SKIPPED
else:
raise ValueError(f'Invalid task id {ti.task_id} found!')
@@ -1018,12 +1016,12 @@ def test_without_dag_run(self):
short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
for ti in tis:
if ti.task_id == 'make_choice':
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
elif ti.task_id == 'upstream':
# should not exist
raise ValueError(f'Invalid task id {ti.task_id} found!')
elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
else:
raise ValueError(f'Invalid task id {ti.task_id} found!')
@@ -1055,14 +1053,14 @@ def test_with_dag_run(self):
short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
tis = dr.get_task_instances()
- self.assertEqual(len(tis), 4)
+ assert len(tis) == 4
for ti in tis:
if ti.task_id == 'make_choice':
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
elif ti.task_id == 'upstream':
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
- self.assertEqual(ti.state, State.SKIPPED)
+ assert ti.state == State.SKIPPED
else:
raise ValueError(f'Invalid task id {ti.task_id} found!')
@@ -1073,14 +1071,14 @@ def test_with_dag_run(self):
short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
tis = dr.get_task_instances()
- self.assertEqual(len(tis), 4)
+ assert len(tis) == 4
for ti in tis:
if ti.task_id == 'make_choice':
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
elif ti.task_id == 'upstream':
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
else:
raise ValueError(f'Invalid task id {ti.task_id} found!')
@@ -1115,9 +1113,9 @@ def test_clear_skipped_downstream_task(self):
for ti in tis:
if ti.task_id == 'make_choice':
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
elif ti.task_id == 'downstream':
- self.assertEqual(ti.state, State.SKIPPED)
+ assert ti.state == State.SKIPPED
else:
raise ValueError(f'Invalid task id {ti.task_id} found!')
@@ -1131,9 +1129,9 @@ def test_clear_skipped_downstream_task(self):
# Check if the states are correct.
for ti in dr.get_task_instances():
if ti.task_id == 'make_choice':
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
elif ti.task_id == 'downstream':
- self.assertEqual(ti.state, State.SKIPPED)
+ assert ti.state == State.SKIPPED
else:
raise ValueError(f'Invalid task id {ti.task_id} found!')
@@ -1227,7 +1225,7 @@ def test_fail(self):
def f():
raise Exception
- with self.assertRaises(CalledProcessError):
+ with pytest.raises(CalledProcessError):
self._run_as_operator(f)
def test_python_2(self):
@@ -1272,7 +1270,7 @@ def test_wrong_python_op_args(self):
def f():
pass
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self._run_as_operator(f, python_version=version, op_args=[1])
def test_without_dill(self):
@@ -1306,7 +1304,7 @@ def f():
self._run_as_operator(f)
def test_lambda(self):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
PythonVirtualenvOperator(python_callable=lambda x: 4, task_id='task', dag=self.dag)
def test_nonimported_as_arg(self):
diff --git a/tests/operators/test_sql.py b/tests/operators/test_sql.py
index e73d2f21ba6fb..3e202909779bb 100644
--- a/tests/operators/test_sql.py
+++ b/tests/operators/test_sql.py
@@ -71,14 +71,14 @@ class TestCheckOperator(unittest.TestCase):
def test_execute_no_records(self, mock_get_db_hook):
mock_get_db_hook.return_value.get_first.return_value = []
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
CheckOperator(sql="sql").execute()
@mock.patch.object(CheckOperator, "get_db_hook")
def test_execute_not_all_records_are_true(self, mock_get_db_hook):
mock_get_db_hook.return_value.get_first.return_value = ["data", ""]
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
CheckOperator(sql="sql").execute()
@@ -105,8 +105,8 @@ def test_pass_value_template_string(self):
operator.render_template_fields({"ds": pass_value_str})
- self.assertEqual(operator.task_id, self.task_id)
- self.assertEqual(operator.pass_value, pass_value_str)
+ assert operator.task_id == self.task_id
+ assert operator.pass_value == pass_value_str
def test_pass_value_template_string_float(self):
pass_value_float = 4.0
@@ -114,8 +114,8 @@ def test_pass_value_template_string_float(self):
operator.render_template_fields({})
- self.assertEqual(operator.task_id, self.task_id)
- self.assertEqual(operator.pass_value, str(pass_value_float))
+ assert operator.task_id == self.task_id
+ assert operator.pass_value == str(pass_value_float)
@mock.patch.object(ValueCheckOperator, "get_db_hook")
def test_execute_pass(self, mock_get_db_hook):
@@ -137,7 +137,7 @@ def test_execute_fail(self, mock_get_db_hook):
operator = self._construct_operator("select value from tab1 limit 1;", 5, 1)
- with self.assertRaisesRegex(AirflowException, "Tolerance:100.0%"):
+ with pytest.raises(AirflowException, match="Tolerance:100.0%"):
operator.execute()
@@ -152,7 +152,7 @@ def _construct_operator(self, table, metric_thresholds, ratio_formula, ignore_ze
)
def test_invalid_ratio_formula(self):
- with self.assertRaisesRegex(AirflowException, "Invalid diff_method"):
+ with pytest.raises(AirflowException, match="Invalid diff_method"):
self._construct_operator(
table="test_table",
metric_thresholds={
@@ -177,7 +177,7 @@ def test_execute_not_ignore_zero(self, mock_get_db_hook):
ignore_zero=False,
)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
operator.execute()
@mock.patch.object(IntervalCheckOperator, "get_db_hook")
@@ -224,7 +224,7 @@ def returned_row():
ignore_zero=True,
)
- with self.assertRaisesRegex(AirflowException, "f0, f1, f2"):
+ with pytest.raises(AirflowException, match="f0, f1, f2"):
operator.execute()
@mock.patch.object(IntervalCheckOperator, "get_db_hook")
@@ -254,7 +254,7 @@ def returned_row():
ignore_zero=True,
)
- with self.assertRaisesRegex(AirflowException, "f0, f1"):
+ with pytest.raises(AirflowException, match="f0, f1"):
operator.execute()
@@ -288,7 +288,7 @@ def test_fail_min_value_max_value(self, mock_get_db_hook):
operator = self._construct_operator("Select avg(val) from table1 limit 1", 20, 100)
- with self.assertRaisesRegex(AirflowException, "10.*20.0.*100.0"):
+ with pytest.raises(AirflowException, match="10.*20.0.*100.0"):
operator.execute()
@mock.patch.object(ThresholdCheckOperator, "get_db_hook")
@@ -309,7 +309,7 @@ def test_fail_min_sql_max_sql(self, mock_get_db_hook):
operator = self._construct_operator("Select 10", "Select 20", "Select 100")
- with self.assertRaisesRegex(AirflowException, "10.*20.*100"):
+ with pytest.raises(AirflowException, match="10.*20.*100"):
operator.execute()
@mock.patch.object(ThresholdCheckOperator, "get_db_hook")
@@ -330,7 +330,7 @@ def test_fail_min_sql_max_value(self, mock_get_db_hook):
operator = self._construct_operator("Select 155", "Select 45", 100)
- with self.assertRaisesRegex(AirflowException, "155.*45.*100.0"):
+ with pytest.raises(AirflowException, match="155.*45.*100.0"):
operator.execute()
@@ -376,7 +376,7 @@ def test_unsupported_conn_type(self):
dag=self.dag,
)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_invalid_conn(self):
@@ -390,7 +390,7 @@ def test_invalid_conn(self):
dag=self.dag,
)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_invalid_follow_task_true(self):
@@ -404,7 +404,7 @@ def test_invalid_follow_task_true(self):
dag=self.dag,
)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_invalid_follow_task_false(self):
@@ -418,7 +418,7 @@ def test_invalid_follow_task_false(self):
dag=self.dag,
)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
@pytest.mark.backend("mysql")
@@ -480,11 +480,11 @@ def test_branch_single_value_with_dag_run(self, mock_hook):
tis = dr.get_task_instances()
for ti in tis:
if ti.task_id == "make_choice":
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
elif ti.task_id == "branch_1":
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
elif ti.task_id == "branch_2":
- self.assertEqual(ti.state, State.SKIPPED)
+ assert ti.state == State.SKIPPED
else:
raise ValueError(f"Invalid task id {ti.task_id} found!")
@@ -522,11 +522,11 @@ def test_branch_true_with_dag_run(self, mock_hook):
tis = dr.get_task_instances()
for ti in tis:
if ti.task_id == "make_choice":
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
elif ti.task_id == "branch_1":
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
elif ti.task_id == "branch_2":
- self.assertEqual(ti.state, State.SKIPPED)
+ assert ti.state == State.SKIPPED
else:
raise ValueError(f"Invalid task id {ti.task_id} found!")
@@ -564,11 +564,11 @@ def test_branch_false_with_dag_run(self, mock_hook):
tis = dr.get_task_instances()
for ti in tis:
if ti.task_id == "make_choice":
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
elif ti.task_id == "branch_1":
- self.assertEqual(ti.state, State.SKIPPED)
+ assert ti.state == State.SKIPPED
elif ti.task_id == "branch_2":
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
else:
raise ValueError(f"Invalid task id {ti.task_id} found!")
@@ -606,13 +606,13 @@ def test_branch_list_with_dag_run(self, mock_hook):
tis = dr.get_task_instances()
for ti in tis:
if ti.task_id == "make_choice":
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
elif ti.task_id == "branch_1":
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
elif ti.task_id == "branch_2":
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
elif ti.task_id == "branch_3":
- self.assertEqual(ti.state, State.SKIPPED)
+ assert ti.state == State.SKIPPED
else:
raise ValueError(f"Invalid task id {ti.task_id} found!")
@@ -644,7 +644,7 @@ def test_invalid_query_result_with_dag_run(self, mock_hook):
mock_get_records.return_value = ["Invalid Value"]
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
@mock.patch("airflow.operators.sql.BaseHook")
@@ -681,11 +681,11 @@ def test_with_skip_in_branch_downstream_dependencies(self, mock_hook):
tis = dr.get_task_instances()
for ti in tis:
if ti.task_id == "make_choice":
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
elif ti.task_id == "branch_1":
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
elif ti.task_id == "branch_2":
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
else:
raise ValueError(f"Invalid task id {ti.task_id} found!")
@@ -723,10 +723,10 @@ def test_with_skip_in_branch_downstream_dependencies2(self, mock_hook):
tis = dr.get_task_instances()
for ti in tis:
if ti.task_id == "make_choice":
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
elif ti.task_id == "branch_1":
- self.assertEqual(ti.state, State.SKIPPED)
+ assert ti.state == State.SKIPPED
elif ti.task_id == "branch_2":
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
else:
raise ValueError(f"Invalid task id {ti.task_id} found!")
diff --git a/tests/operators/test_subdag_operator.py b/tests/operators/test_subdag_operator.py
index 86db2a3b68556..823b99c1083cf 100644
--- a/tests/operators/test_subdag_operator.py
+++ b/tests/operators/test_subdag_operator.py
@@ -20,6 +20,7 @@
from unittest import mock
from unittest.mock import Mock
+import pytest
from parameterized import parameterized
import airflow
@@ -62,9 +63,12 @@ def test_subdag_name(self):
subdag_bad3 = DAG('bad.bad', default_args=default_args)
SubDagOperator(task_id='test', dag=dag, subdag=subdag_good)
- self.assertRaises(AirflowException, SubDagOperator, task_id='test', dag=dag, subdag=subdag_bad1)
- self.assertRaises(AirflowException, SubDagOperator, task_id='test', dag=dag, subdag=subdag_bad2)
- self.assertRaises(AirflowException, SubDagOperator, task_id='test', dag=dag, subdag=subdag_bad3)
+ with pytest.raises(AirflowException):
+ SubDagOperator(task_id='test', dag=dag, subdag=subdag_bad1)
+ with pytest.raises(AirflowException):
+ SubDagOperator(task_id='test', dag=dag, subdag=subdag_bad2)
+ with pytest.raises(AirflowException):
+ SubDagOperator(task_id='test', dag=dag, subdag=subdag_bad3)
def test_subdag_in_context_manager(self):
"""
@@ -74,8 +78,8 @@ def test_subdag_in_context_manager(self):
subdag = DAG('parent.test', default_args=default_args)
op = SubDagOperator(task_id='test', subdag=subdag)
- self.assertEqual(op.dag, dag)
- self.assertEqual(op.subdag, subdag)
+ assert op.dag == dag
+ assert op.subdag == subdag
def test_subdag_pools(self):
"""
@@ -93,9 +97,8 @@ def test_subdag_pools(self):
DummyOperator(task_id='dummy', dag=subdag, pool='test_pool_1')
- self.assertRaises(
- AirflowException, SubDagOperator, task_id='child', dag=dag, subdag=subdag, pool='test_pool_1'
- )
+ with pytest.raises(AirflowException):
+ SubDagOperator(task_id='child', dag=dag, subdag=subdag, pool='test_pool_1')
# recreate dag because failed subdagoperator was already added
dag = DAG('parent', default_args=default_args)
@@ -124,7 +127,7 @@ def test_subdag_pools_no_possible_conflict(self):
mock_session = Mock()
SubDagOperator(task_id='child', dag=dag, subdag=subdag, pool='test_pool_1', session=mock_session)
- self.assertFalse(mock_session.query.called)
+ assert not mock_session.query.called
session.delete(pool_1)
session.delete(pool_10)
@@ -157,7 +160,7 @@ def test_execute_create_dagrun_wait_until_success(self):
external_trigger=True,
)
- self.assertEqual(3, len(subdag_task._get_dagrun.mock_calls))
+ assert 3 == len(subdag_task._get_dagrun.mock_calls)
def test_execute_create_dagrun_with_conf(self):
"""
@@ -187,7 +190,7 @@ def test_execute_create_dagrun_with_conf(self):
external_trigger=True,
)
- self.assertEqual(3, len(subdag_task._get_dagrun.mock_calls))
+ assert 3 == len(subdag_task._get_dagrun.mock_calls)
def test_execute_dagrun_failed(self):
"""
@@ -203,7 +206,7 @@ def test_execute_dagrun_failed(self):
subdag_task._get_dagrun = Mock()
subdag_task._get_dagrun.side_effect = [None, self.dag_run_failed, self.dag_run_failed]
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
subdag_task.pre_execute(context={'execution_date': DEFAULT_DATE})
subdag_task.execute(context={'execution_date': DEFAULT_DATE})
subdag_task.post_execute(context={'execution_date': DEFAULT_DATE})
@@ -225,7 +228,7 @@ def test_execute_skip_if_dagrun_success(self):
subdag_task.post_execute(context={'execution_date': DEFAULT_DATE})
subdag.create_dagrun.assert_not_called()
- self.assertEqual(3, len(subdag_task._get_dagrun.mock_calls))
+ assert 3 == len(subdag_task._get_dagrun.mock_calls)
def test_rerun_failed_subdag(self):
"""
@@ -256,10 +259,10 @@ def test_rerun_failed_subdag(self):
subdag_task._reset_dag_run_and_task_instances(sub_dagrun, execution_date=DEFAULT_DATE)
dummy_task_instance.refresh_from_db()
- self.assertEqual(dummy_task_instance.state, State.NONE)
+ assert dummy_task_instance.state == State.NONE
sub_dagrun.refresh_from_db()
- self.assertEqual(sub_dagrun.state, State.RUNNING)
+ assert sub_dagrun.state == State.RUNNING
@parameterized.expand(
[
diff --git a/tests/operators/test_trigger_dagrun.py b/tests/operators/test_trigger_dagrun.py
index 42bc2a2c1b2b8..c17d43c032311 100644
--- a/tests/operators/test_trigger_dagrun.py
+++ b/tests/operators/test_trigger_dagrun.py
@@ -21,6 +21,8 @@
from datetime import datetime
from unittest import TestCase
+import pytest
+
from airflow.exceptions import AirflowException, DagRunAlreadyExists
from airflow.models import DAG, DagBag, DagModel, DagRun, Log, TaskInstance
from airflow.models.serialized_dag import SerializedDagModel
@@ -81,8 +83,8 @@ def test_trigger_dagrun(self):
with create_session() as session:
dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all()
- self.assertEqual(len(dagruns), 1)
- self.assertTrue(dagruns[0].external_trigger)
+ assert len(dagruns) == 1
+ assert dagruns[0].external_trigger
def test_trigger_dagrun_with_execution_date(self):
"""Test TriggerDagRunOperator with custom execution_date."""
@@ -97,9 +99,9 @@ def test_trigger_dagrun_with_execution_date(self):
with create_session() as session:
dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all()
- self.assertEqual(len(dagruns), 1)
- self.assertTrue(dagruns[0].external_trigger)
- self.assertEqual(dagruns[0].execution_date, utc_now)
+ assert len(dagruns) == 1
+ assert dagruns[0].external_trigger
+ assert dagruns[0].execution_date == utc_now
def test_trigger_dagrun_twice(self):
"""Test TriggerDagRunOperator with custom execution_date."""
@@ -127,9 +129,9 @@ def test_trigger_dagrun_twice(self):
task.execute(None)
dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all()
- self.assertEqual(len(dagruns), 1)
- self.assertTrue(dagruns[0].external_trigger)
- self.assertEqual(dagruns[0].execution_date, utc_now)
+ assert len(dagruns) == 1
+ assert dagruns[0].external_trigger
+ assert dagruns[0].execution_date == utc_now
def test_trigger_dagrun_with_templated_execution_date(self):
"""Test TriggerDagRunOperator with templated execution_date."""
@@ -143,9 +145,9 @@ def test_trigger_dagrun_with_templated_execution_date(self):
with create_session() as session:
dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all()
- self.assertEqual(len(dagruns), 1)
- self.assertTrue(dagruns[0].external_trigger)
- self.assertEqual(dagruns[0].execution_date, DEFAULT_DATE)
+ assert len(dagruns) == 1
+ assert dagruns[0].external_trigger
+ assert dagruns[0].execution_date == DEFAULT_DATE
def test_trigger_dagrun_operator_conf(self):
"""Test passing conf to the triggered DagRun."""
@@ -159,8 +161,8 @@ def test_trigger_dagrun_operator_conf(self):
with create_session() as session:
dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all()
- self.assertEqual(len(dagruns), 1)
- self.assertTrue(dagruns[0].conf, {"foo": "bar"})
+ assert len(dagruns) == 1
+ assert dagruns[0].conf, {"foo": "bar"}
def test_trigger_dagrun_operator_templated_conf(self):
"""Test passing a templated conf to the triggered DagRun."""
@@ -174,8 +176,8 @@ def test_trigger_dagrun_operator_templated_conf(self):
with create_session() as session:
dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all()
- self.assertEqual(len(dagruns), 1)
- self.assertTrue(dagruns[0].conf, {"foo": TEST_DAG_ID})
+ assert len(dagruns) == 1
+ assert dagruns[0].conf, {"foo": TEST_DAG_ID}
def test_trigger_dagrun_with_reset_dag_run_false(self):
"""Test TriggerDagRunOperator with reset_dag_run."""
@@ -189,7 +191,7 @@ def test_trigger_dagrun_with_reset_dag_run_false(self):
)
task.run(start_date=execution_date, end_date=execution_date, ignore_ti_state=True)
- with self.assertRaises(DagRunAlreadyExists):
+ with pytest.raises(DagRunAlreadyExists):
task.run(start_date=execution_date, end_date=execution_date, ignore_ti_state=True)
def test_trigger_dagrun_with_reset_dag_run_true(self):
@@ -207,8 +209,8 @@ def test_trigger_dagrun_with_reset_dag_run_true(self):
with create_session() as session:
dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all()
- self.assertEqual(len(dagruns), 1)
- self.assertTrue(dagruns[0].external_trigger)
+ assert len(dagruns) == 1
+ assert dagruns[0].external_trigger
def test_trigger_dagrun_with_wait_for_completion_true(self):
"""Test TriggerDagRunOperator with wait_for_completion."""
@@ -226,7 +228,7 @@ def test_trigger_dagrun_with_wait_for_completion_true(self):
with create_session() as session:
dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all()
- self.assertEqual(len(dagruns), 1)
+ assert len(dagruns) == 1
def test_trigger_dagrun_with_wait_for_completion_true_fail(self):
"""Test TriggerDagRunOperator with wait_for_completion but triggered dag fails."""
@@ -240,5 +242,5 @@ def test_trigger_dagrun_with_wait_for_completion_true_fail(self):
failed_states=[State.RUNNING],
dag=self.dag,
)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
task.run(start_date=execution_date, end_date=execution_date)
diff --git a/tests/plugins/test_plugin_ignore.py b/tests/plugins/test_plugin_ignore.py
index 639efc285f426..451745eb7a169 100644
--- a/tests/plugins/test_plugin_ignore.py
+++ b/tests/plugins/test_plugin_ignore.py
@@ -92,5 +92,5 @@ def test_find_not_should_ignore_path(self):
if file_ext != '.py':
continue
detected_files.add(os.path.basename(file_path))
- self.assertEqual(detected_files, should_not_ignore_files)
- self.assertEqual(detected_files & should_ignore_files, set())
+ assert detected_files == should_not_ignore_files
+ assert detected_files & should_ignore_files == set()
diff --git a/tests/plugins/test_plugins_manager.py b/tests/plugins/test_plugins_manager.py
index c9d165ef88068..0dcdd534fc37e 100644
--- a/tests/plugins/test_plugins_manager.py
+++ b/tests/plugins/test_plugins_manager.py
@@ -43,7 +43,7 @@ def test_flaskappbuilder_views(self):
view for view in self.appbuilder.baseviews if view.blueprint.name == appbuilder_class_name
]
- self.assertTrue(len(plugin_views) == 1)
+ assert len(plugin_views) == 1
# view should have a menu item matching category of v_appbuilder_package
links = [
@@ -52,12 +52,12 @@ def test_flaskappbuilder_views(self):
if menu_item.name == v_appbuilder_package['category']
]
- self.assertTrue(len(links) == 1)
+ assert len(links) == 1
# menu link should also have a link matching the name of the package.
link = links[0]
- self.assertEqual(link.name, v_appbuilder_package['category'])
- self.assertEqual(link.childs[0].name, v_appbuilder_package['name'])
+ assert link.name == v_appbuilder_package['category']
+ assert link.childs[0].name == v_appbuilder_package['name']
def test_flaskappbuilder_nomenu_views(self):
from tests.plugins.test_plugin import v_nomenu_appbuilder_package
@@ -74,7 +74,7 @@ class AirflowNoMenuViewsPlugin(AirflowPlugin):
view for view in appbuilder.baseviews if view.blueprint.name == appbuilder_class_name
]
- self.assertTrue(len(plugin_views) == 1)
+ assert len(plugin_views) == 1
def test_flaskappbuilder_menu_links(self):
from tests.plugins.test_plugin import appbuilder_mitem
@@ -86,19 +86,19 @@ def test_flaskappbuilder_menu_links(self):
if menu_item.name == appbuilder_mitem['category']
]
- self.assertTrue(len(links) == 1)
+ assert len(links) == 1
# menu link should also have a link matching the name of the package.
link = links[0]
- self.assertEqual(link.name, appbuilder_mitem['category'])
- self.assertEqual(link.childs[0].name, appbuilder_mitem['name'])
+ assert link.name == appbuilder_mitem['category']
+ assert link.childs[0].name == appbuilder_mitem['name']
def test_app_blueprints(self):
from tests.plugins.test_plugin import bp
# Blueprint should be present in the app
- self.assertTrue('test_plugin' in self.app.blueprints)
- self.assertEqual(self.app.blueprints['test_plugin'].name, bp.name)
+ assert 'test_plugin' in self.app.blueprints
+ assert self.app.blueprints['test_plugin'].name == bp.name
class TestPluginsManager:
@@ -281,9 +281,9 @@ def test_should_return_correct_path_name(self):
from airflow import plugins_manager
source = plugins_manager.PluginsDirectorySource(__file__)
- self.assertEqual("test_plugins_manager.py", source.path)
- self.assertEqual("$PLUGINS_FOLDER/test_plugins_manager.py", str(source))
- self.assertEqual("$PLUGINS_FOLDER/test_plugins_manager.py", source.__html__())
+ assert "test_plugins_manager.py" == source.path
+ assert "$PLUGINS_FOLDER/test_plugins_manager.py" == str(source)
+ assert "$PLUGINS_FOLDER/test_plugins_manager.py" == source.__html__()
class TestEntryPointSource:
diff --git a/tests/providers/amazon/aws/hooks/test_athena.py b/tests/providers/amazon/aws/hooks/test_athena.py
index a9bd757464d6d..5289dd940e49f 100644
--- a/tests/providers/amazon/aws/hooks/test_athena.py
+++ b/tests/providers/amazon/aws/hooks/test_athena.py
@@ -45,8 +45,8 @@ def setUp(self):
self.athena = AWSAthenaHook(sleep_time=0)
def test_init(self):
- self.assertEqual(self.athena.aws_conn_id, 'aws_default')
- self.assertEqual(self.athena.sleep_time, 0)
+ assert self.athena.aws_conn_id == 'aws_default'
+ assert self.athena.sleep_time == 0
@mock.patch.object(AWSAthenaHook, 'get_conn')
def test_hook_run_query_without_token(self, mock_conn):
@@ -63,7 +63,7 @@ def test_hook_run_query_without_token(self, mock_conn):
'WorkGroup': MOCK_DATA['workgroup'],
}
mock_conn.return_value.start_query_execution.assert_called_with(**expected_call_params)
- self.assertEqual(result, MOCK_DATA['query_execution_id'])
+ assert result == MOCK_DATA['query_execution_id']
@mock.patch.object(AWSAthenaHook, 'get_conn')
def test_hook_run_query_with_token(self, mock_conn):
@@ -82,13 +82,13 @@ def test_hook_run_query_with_token(self, mock_conn):
'WorkGroup': MOCK_DATA['workgroup'],
}
mock_conn.return_value.start_query_execution.assert_called_with(**expected_call_params)
- self.assertEqual(result, MOCK_DATA['query_execution_id'])
+ assert result == MOCK_DATA['query_execution_id']
@mock.patch.object(AWSAthenaHook, 'get_conn')
def test_hook_get_query_results_with_non_succeeded_query(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_RUNNING_QUERY_EXECUTION
result = self.athena.get_query_results(query_execution_id=MOCK_DATA['query_execution_id'])
- self.assertIsNone(result)
+ assert result is None
@mock.patch.object(AWSAthenaHook, 'get_conn')
def test_hook_get_query_results_with_default_params(self, mock_conn):
@@ -114,7 +114,7 @@ def test_hook_get_query_results_with_next_token(self, mock_conn):
def test_hook_get_paginator_with_non_succeeded_query(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_RUNNING_QUERY_EXECUTION
result = self.athena.get_query_results_paginator(query_execution_id=MOCK_DATA['query_execution_id'])
- self.assertIsNone(result)
+ assert result is None
@mock.patch.object(AWSAthenaHook, 'get_conn')
def test_hook_get_paginator_with_default_params(self, mock_conn):
@@ -150,7 +150,7 @@ def test_hook_poll_query_when_final(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_SUCCEEDED_QUERY_EXECUTION
result = self.athena.poll_query_status(query_execution_id=MOCK_DATA['query_execution_id'])
mock_conn.return_value.get_query_execution.assert_called_once()
- self.assertEqual(result, 'SUCCEEDED')
+ assert result == 'SUCCEEDED'
@mock.patch.object(AWSAthenaHook, 'get_conn')
def test_hook_poll_query_with_timeout(self, mock_conn):
@@ -159,7 +159,7 @@ def test_hook_poll_query_with_timeout(self, mock_conn):
query_execution_id=MOCK_DATA['query_execution_id'], max_tries=1
)
mock_conn.return_value.get_query_execution.assert_called_once()
- self.assertEqual(result, 'RUNNING')
+ assert result == 'RUNNING'
if __name__ == '__main__':
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py
index 644bd67afdece..da8f8c8a4119f 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -45,7 +45,7 @@ def test_get_client_type_returns_a_boto3_client_of_the_requested_type(self):
hook = AwsBaseHook(aws_conn_id='aws_default', client_type='emr')
client_from_hook = hook.get_client_type('emr')
- self.assertEqual(client_from_hook.list_clusters()['Clusters'], [])
+ assert client_from_hook.list_clusters()['Clusters'] == []
@unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not present')
@mock_dynamodb2
@@ -65,7 +65,7 @@ def test_get_resource_type_returns_a_boto3_resource_of_the_requested_type(self):
table.meta.client.get_waiter('table_exists').wait(TableName='test_airflow')
- self.assertEqual(table.item_count, 0)
+ assert table.item_count == 0
@unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not present')
@mock_dynamodb2
@@ -84,7 +84,7 @@ def test_get_session_returns_a_boto3_session(self):
table.meta.client.get_waiter('table_exists').wait(TableName='test_airflow')
- self.assertEqual(table.item_count, 0)
+ assert table.item_count == 0
@mock.patch.object(AwsBaseHook, 'get_connection')
def test_get_credentials_from_login_with_token(self, mock_get_connection):
@@ -96,9 +96,9 @@ def test_get_credentials_from_login_with_token(self, mock_get_connection):
mock_get_connection.return_value = mock_connection
hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
credentials_from_hook = hook.get_credentials()
- self.assertEqual(credentials_from_hook.access_key, 'aws_access_key_id')
- self.assertEqual(credentials_from_hook.secret_key, 'aws_secret_access_key')
- self.assertEqual(credentials_from_hook.token, 'test_token')
+ assert credentials_from_hook.access_key == 'aws_access_key_id'
+ assert credentials_from_hook.secret_key == 'aws_secret_access_key'
+ assert credentials_from_hook.token == 'test_token'
@mock.patch.object(AwsBaseHook, 'get_connection')
def test_get_credentials_from_login_without_token(self, mock_get_connection):
@@ -110,9 +110,9 @@ def test_get_credentials_from_login_without_token(self, mock_get_connection):
mock_get_connection.return_value = mock_connection
hook = AwsBaseHook(aws_conn_id='aws_default', client_type='spam')
credentials_from_hook = hook.get_credentials()
- self.assertEqual(credentials_from_hook.access_key, 'aws_access_key_id')
- self.assertEqual(credentials_from_hook.secret_key, 'aws_secret_access_key')
- self.assertIsNone(credentials_from_hook.token)
+ assert credentials_from_hook.access_key == 'aws_access_key_id'
+ assert credentials_from_hook.secret_key == 'aws_secret_access_key'
+ assert credentials_from_hook.token is None
@mock.patch.object(AwsBaseHook, 'get_connection')
def test_get_credentials_from_extra_with_token(self, mock_get_connection):
@@ -124,9 +124,9 @@ def test_get_credentials_from_extra_with_token(self, mock_get_connection):
mock_get_connection.return_value = mock_connection
hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
credentials_from_hook = hook.get_credentials()
- self.assertEqual(credentials_from_hook.access_key, 'aws_access_key_id')
- self.assertEqual(credentials_from_hook.secret_key, 'aws_secret_access_key')
- self.assertEqual(credentials_from_hook.token, 'session_token')
+ assert credentials_from_hook.access_key == 'aws_access_key_id'
+ assert credentials_from_hook.secret_key == 'aws_secret_access_key'
+ assert credentials_from_hook.token == 'session_token'
@mock.patch.object(AwsBaseHook, 'get_connection')
def test_get_credentials_from_extra_without_token(self, mock_get_connection):
@@ -137,9 +137,9 @@ def test_get_credentials_from_extra_without_token(self, mock_get_connection):
mock_get_connection.return_value = mock_connection
hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
credentials_from_hook = hook.get_credentials()
- self.assertEqual(credentials_from_hook.access_key, 'aws_access_key_id')
- self.assertEqual(credentials_from_hook.secret_key, 'aws_secret_access_key')
- self.assertIsNone(credentials_from_hook.token)
+ assert credentials_from_hook.access_key == 'aws_access_key_id'
+ assert credentials_from_hook.secret_key == 'aws_secret_access_key'
+ assert credentials_from_hook.token is None
@mock.patch(
'airflow.providers.amazon.aws.hooks.base_aws._parse_s3_config',
@@ -168,13 +168,13 @@ def test_get_credentials_from_role_arn(self, mock_get_connection):
mock_get_connection.return_value = mock_connection
hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
credentials_from_hook = hook.get_credentials()
- self.assertIn("ASIA", credentials_from_hook.access_key)
+ assert "ASIA" in credentials_from_hook.access_key
# We assert the length instead of actual values as the values are random:
# Details: https://github.com/spulec/moto/commit/ab0d23a0ba2506e6338ae20b3fde70da049f7b03
- self.assertEqual(20, len(credentials_from_hook.access_key))
- self.assertEqual(40, len(credentials_from_hook.secret_key))
- self.assertEqual(356, len(credentials_from_hook.token))
+ assert 20 == len(credentials_from_hook.access_key)
+ assert 40 == len(credentials_from_hook.secret_key)
+ assert 356 == len(credentials_from_hook.token)
def test_get_credentials_from_gcp_credentials(self):
mock_connection = Connection(
@@ -208,9 +208,8 @@ def import_mock(name, *args):
credentials_from_hook = hook.get_credentials()
mock_get_credentials = mock_boto3.session.Session.return_value.get_credentials
- self.assertEqual(
- mock_get_credentials.return_value.get_frozen_credentials.return_value,
- credentials_from_hook,
+ assert (
+ mock_get_credentials.return_value.get_frozen_credentials.return_value == credentials_from_hook
)
mock_boto3.assert_has_calls(
@@ -260,7 +259,7 @@ def test_expand_role(self):
hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
arn = hook.expand_role('test-role')
expect_arn = conn.get_role(RoleName='test-role').get('Role').get('Arn')
- self.assertEqual(arn, expect_arn)
+ assert arn == expect_arn
def test_use_default_boto3_behaviour_without_conn_id(self):
for conn_id in (None, ''):
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws_system.py b/tests/providers/amazon/aws/hooks/test_base_aws_system.py
index 50bb00cb3524d..cb9b674eee4f7 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws_system.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws_system.py
@@ -52,4 +52,4 @@ def test_run_example_gcp_vision_autogenerated_id_dag(self):
client = hook.get_conn()
response = client.list_buckets()
- self.assertIn('Buckets', response)
+ assert 'Buckets' in response
diff --git a/tests/providers/amazon/aws/hooks/test_batch_client.py b/tests/providers/amazon/aws/hooks/test_batch_client.py
index 07e97d9147377..6f331fd8d2902 100644
--- a/tests/providers/amazon/aws/hooks/test_batch_client.py
+++ b/tests/providers/amazon/aws/hooks/test_batch_client.py
@@ -22,6 +22,7 @@
from unittest import mock
import botocore.exceptions
+import pytest
from parameterized import parameterized
from airflow.exceptions import AirflowException
@@ -53,7 +54,7 @@ def setUp(self, get_client_type_mock):
region_name=AWS_REGION,
)
self.client_mock = get_client_type_mock.return_value
- self.assertEqual(self.batch_client.client, self.client_mock) # setup client property
+ assert self.batch_client.client == self.client_mock # setup client property
# don't pause in these unit tests
self.mock_delay = mock.Mock(return_value=None)
@@ -62,11 +63,11 @@ def setUp(self, get_client_type_mock):
self.batch_client.exponential_delay = self.mock_exponential_delay
def test_init(self):
- self.assertEqual(self.batch_client.max_retries, self.MAX_RETRIES)
- self.assertEqual(self.batch_client.status_retries, self.STATUS_RETRIES)
- self.assertEqual(self.batch_client.region_name, AWS_REGION)
- self.assertEqual(self.batch_client.aws_conn_id, 'airflow_test')
- self.assertEqual(self.batch_client.client, self.client_mock)
+ assert self.batch_client.max_retries == self.MAX_RETRIES
+ assert self.batch_client.status_retries == self.STATUS_RETRIES
+ assert self.batch_client.region_name == AWS_REGION
+ assert self.batch_client.aws_conn_id == 'airflow_test'
+ assert self.batch_client.client == self.client_mock
self.get_client_type_mock.assert_called_once_with("batch", region_name=AWS_REGION)
@@ -89,7 +90,7 @@ def test_wait_for_job_with_success(self):
self.batch_client.wait_for_job(JOB_ID)
job_complete.assert_called_once_with(JOB_ID, None)
- self.assertEqual(self.client_mock.describe_jobs.call_count, 4)
+ assert self.client_mock.describe_jobs.call_count == 4
def test_wait_for_job_with_failure(self):
self.client_mock.describe_jobs.return_value = {"jobs": [{"jobId": JOB_ID, "status": "FAILED"}]}
@@ -110,7 +111,7 @@ def test_wait_for_job_with_failure(self):
self.batch_client.wait_for_job(JOB_ID)
job_complete.assert_called_once_with(JOB_ID, None)
- self.assertEqual(self.client_mock.describe_jobs.call_count, 4)
+ assert self.client_mock.describe_jobs.call_count == 4
def test_poll_job_running_for_status_running(self):
self.client_mock.describe_jobs.return_value = {"jobs": [{"jobId": JOB_ID, "status": "RUNNING"}]}
@@ -124,42 +125,42 @@ def test_poll_job_complete_for_status_success(self):
def test_poll_job_complete_raises_for_max_retries(self):
self.client_mock.describe_jobs.return_value = {"jobs": [{"jobId": JOB_ID, "status": "RUNNING"}]}
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
self.batch_client.poll_for_job_complete(JOB_ID)
msg = f"AWS Batch job ({JOB_ID}) status checks exceed max_retries"
- self.assertIn(msg, str(e.exception))
+ assert msg in str(ctx.value)
self.client_mock.describe_jobs.assert_called_with(jobs=[JOB_ID])
- self.assertEqual(self.client_mock.describe_jobs.call_count, self.MAX_RETRIES + 1)
+ assert self.client_mock.describe_jobs.call_count == self.MAX_RETRIES + 1
def test_poll_job_status_hit_api_throttle(self):
self.client_mock.describe_jobs.side_effect = botocore.exceptions.ClientError(
error_response={"Error": {"Code": "TooManyRequestsException"}},
operation_name="get job description",
)
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
self.batch_client.poll_for_job_complete(JOB_ID)
msg = f"AWS Batch job ({JOB_ID}) description error"
- self.assertIn(msg, str(e.exception))
+ assert msg in str(ctx.value)
# It should retry when this client error occurs
self.client_mock.describe_jobs.assert_called_with(jobs=[JOB_ID])
- self.assertEqual(self.client_mock.describe_jobs.call_count, self.STATUS_RETRIES)
+ assert self.client_mock.describe_jobs.call_count == self.STATUS_RETRIES
def test_poll_job_status_with_client_error(self):
self.client_mock.describe_jobs.side_effect = botocore.exceptions.ClientError(
error_response={"Error": {"Code": "InvalidClientTokenId"}},
operation_name="get job description",
)
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
self.batch_client.poll_for_job_complete(JOB_ID)
msg = f"AWS Batch job ({JOB_ID}) description error"
- self.assertIn(msg, str(e.exception))
+ assert msg in str(ctx.value)
# It will not retry when this client error occurs
self.client_mock.describe_jobs.assert_called_once_with(jobs=[JOB_ID])
def test_check_job_success(self):
self.client_mock.describe_jobs.return_value = {"jobs": [{"jobId": JOB_ID, "status": "SUCCEEDED"}]}
status = self.batch_client.check_job_success(JOB_ID)
- self.assertTrue(status)
+ assert status
self.client_mock.describe_jobs.assert_called_once_with(jobs=[JOB_ID])
def test_check_job_success_raises_failed(self):
@@ -173,11 +174,11 @@ def test_check_job_success_raises_failed(self):
}
]
}
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
self.batch_client.check_job_success(JOB_ID)
self.client_mock.describe_jobs.assert_called_once_with(jobs=[JOB_ID])
msg = f"AWS Batch job ({JOB_ID}) failed"
- self.assertIn(msg, str(e.exception))
+ assert msg in str(ctx.value)
def test_check_job_success_raises_failed_for_multiple_attempts(self):
self.client_mock.describe_jobs.return_value = {
@@ -190,44 +191,44 @@ def test_check_job_success_raises_failed_for_multiple_attempts(self):
}
]
}
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
self.batch_client.check_job_success(JOB_ID)
self.client_mock.describe_jobs.assert_called_once_with(jobs=[JOB_ID])
msg = f"AWS Batch job ({JOB_ID}) failed"
- self.assertIn(msg, str(e.exception))
+ assert msg in str(ctx.value)
def test_check_job_success_raises_incomplete(self):
self.client_mock.describe_jobs.return_value = {"jobs": [{"jobId": JOB_ID, "status": "RUNNABLE"}]}
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
self.batch_client.check_job_success(JOB_ID)
self.client_mock.describe_jobs.assert_called_once_with(jobs=[JOB_ID])
msg = f"AWS Batch job ({JOB_ID}) is not complete"
- self.assertIn(msg, str(e.exception))
+ assert msg in str(ctx.value)
def test_check_job_success_raises_unknown_status(self):
status = "STRANGE"
self.client_mock.describe_jobs.return_value = {"jobs": [{"jobId": JOB_ID, "status": status}]}
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
self.batch_client.check_job_success(JOB_ID)
self.client_mock.describe_jobs.assert_called_once_with(jobs=[JOB_ID])
msg = f"AWS Batch job ({JOB_ID}) has unknown status"
- self.assertIn(msg, str(e.exception))
- self.assertIn(status, str(e.exception))
+ assert msg in str(ctx.value)
+ assert status in str(ctx.value)
def test_check_job_success_raises_without_jobs(self):
self.client_mock.describe_jobs.return_value = {"jobs": []}
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
self.batch_client.check_job_success(JOB_ID)
self.client_mock.describe_jobs.assert_called_once_with(jobs=[JOB_ID])
msg = f"AWS Batch job ({JOB_ID}) description error"
- self.assertIn(msg, str(e.exception))
+ assert msg in str(ctx.value)
def test_terminate_job(self):
self.client_mock.terminate_job.return_value = {}
reason = "Task killed by the user"
response = self.batch_client.terminate_job(JOB_ID, reason)
self.client_mock.terminate_job.assert_called_once_with(jobId=JOB_ID, reason=reason)
- self.assertEqual(response, {})
+ assert response == {}
class TestAwsBatchClientDelays(unittest.TestCase):
@@ -238,23 +239,23 @@ def setUp(self):
self.batch_client = AwsBatchClientHook(aws_conn_id='airflow_test', region_name=AWS_REGION)
def test_init(self):
- self.assertEqual(self.batch_client.max_retries, self.batch_client.MAX_RETRIES)
- self.assertEqual(self.batch_client.status_retries, self.batch_client.STATUS_RETRIES)
- self.assertEqual(self.batch_client.region_name, AWS_REGION)
- self.assertEqual(self.batch_client.aws_conn_id, 'airflow_test')
+ assert self.batch_client.max_retries == self.batch_client.MAX_RETRIES
+ assert self.batch_client.status_retries == self.batch_client.STATUS_RETRIES
+ assert self.batch_client.region_name == AWS_REGION
+ assert self.batch_client.aws_conn_id == 'airflow_test'
def test_add_jitter(self):
minima = 0
width = 5
result = self.batch_client.add_jitter(0, width=width, minima=minima)
- self.assertGreaterEqual(result, minima)
- self.assertLessEqual(result, width)
+ assert result >= minima
+ assert result <= width
@mock.patch("airflow.providers.amazon.aws.hooks.batch_client.uniform")
@mock.patch("airflow.providers.amazon.aws.hooks.batch_client.sleep")
def test_delay_defaults(self, mock_sleep, mock_uniform):
- self.assertEqual(AwsBatchClientHook.DEFAULT_DELAY_MIN, 1)
- self.assertEqual(AwsBatchClientHook.DEFAULT_DELAY_MAX, 10)
+ assert AwsBatchClientHook.DEFAULT_DELAY_MIN == 1
+ assert AwsBatchClientHook.DEFAULT_DELAY_MAX == 10
mock_uniform.return_value = 0
self.batch_client.delay()
mock_uniform.assert_called_once_with(
@@ -300,5 +301,5 @@ def test_delay_with_float(self, mock_sleep, mock_uniform):
)
def test_exponential_delay(self, tries, lower, upper):
result = self.batch_client.exponential_delay(tries)
- self.assertGreaterEqual(result, lower)
- self.assertLessEqual(result, upper)
+ assert result >= lower
+ assert result <= upper
diff --git a/tests/providers/amazon/aws/hooks/test_batch_waiters.py b/tests/providers/amazon/aws/hooks/test_batch_waiters.py
index 400c3bedccfe6..51d42e0156437 100644
--- a/tests/providers/amazon/aws/hooks/test_batch_waiters.py
+++ b/tests/providers/amazon/aws/hooks/test_batch_waiters.py
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=do-not-use-asserts, missing-docstring, redefined-outer-name
+# pylint: disable=missing-docstring, redefined-outer-name
"""
@@ -265,12 +265,12 @@ def test_aws_batch_job_waiting(aws_clients, aws_region, job_queue_name, job_defi
assert job_complete_waiter.__class__.__name__ == "Batch.Waiter.JobComplete"
# test waiting on a jobId that does not exist (this throws immediately)
- with pytest.raises(botocore.exceptions.WaiterError) as err:
+ with pytest.raises(botocore.exceptions.WaiterError) as ctx:
job_exists_waiter.config.delay = 0.2
job_exists_waiter.config.max_attempts = 2
job_exists_waiter.wait(jobs=["missing-job"])
- assert isinstance(err.value, botocore.exceptions.WaiterError)
- assert "Waiter JobExists failed" in str(err.value)
+ assert isinstance(ctx.value, botocore.exceptions.WaiterError)
+ assert "Waiter JobExists failed" in str(ctx.value)
# Submit a job and wait for various job status indicators;
# moto transitions the batch job status automatically.
@@ -300,10 +300,10 @@ def test_aws_batch_job_waiting(aws_clients, aws_region, job_queue_name, job_defi
# test waiting for job completion with too few attempts (possibly before job is running)
job_complete_waiter.config.delay = 0.1
job_complete_waiter.config.max_attempts = 1
- with pytest.raises(botocore.exceptions.WaiterError) as err:
+ with pytest.raises(botocore.exceptions.WaiterError) as ctx:
job_complete_waiter.wait(jobs=[job_id])
- assert isinstance(err.value, botocore.exceptions.WaiterError)
- assert "Waiter JobComplete failed: Max attempts exceeded" in str(err.value)
+ assert isinstance(ctx.value, botocore.exceptions.WaiterError)
+ assert "Waiter JobComplete failed: Max attempts exceeded" in str(ctx.value)
# wait for job to be running (or complete)
job_running_waiter.config.delay = 0.25 # sec delays between status checks
@@ -320,9 +320,6 @@ def test_aws_batch_job_waiting(aws_clients, aws_region, job_queue_name, job_defi
assert job_status == "SUCCEEDED"
-# pylint: enable=do-not-use-asserts
-
-
class TestAwsBatchWaiters(unittest.TestCase):
@mock.patch.dict("os.environ", AWS_DEFAULT_REGION=AWS_REGION)
@mock.patch.dict("os.environ", AWS_ACCESS_KEY_ID=AWS_ACCESS_KEY_ID)
@@ -333,8 +330,8 @@ def setUp(self, get_client_type_mock):
self.region_name = AWS_REGION
self.batch_waiters = AwsBatchWaitersHook(region_name=self.region_name)
- self.assertEqual(self.batch_waiters.aws_conn_id, 'aws_default')
- self.assertEqual(self.batch_waiters.region_name, self.region_name)
+ assert self.batch_waiters.aws_conn_id == 'aws_default'
+ assert self.batch_waiters.region_name == self.region_name
# init the mock client
self.client_mock = self.batch_waiters.client
@@ -349,48 +346,48 @@ def setUp(self, get_client_type_mock):
def test_default_config(self):
# the default config is used when no custom config is provided
config = self.batch_waiters.default_config
- self.assertEqual(config, self.batch_waiters.waiter_config)
+ assert config == self.batch_waiters.waiter_config
- self.assertIsInstance(config, dict)
- self.assertEqual(config["version"], 2)
- self.assertIsInstance(config["waiters"], dict)
+ assert isinstance(config, dict)
+ assert config["version"] == 2
+ assert isinstance(config["waiters"], dict)
waiters = list(sorted(config["waiters"].keys()))
- self.assertEqual(waiters, ["JobComplete", "JobExists", "JobRunning"])
+ assert waiters == ["JobComplete", "JobExists", "JobRunning"]
def test_list_waiters(self):
# the default config is used when no custom config is provided
config = self.batch_waiters.waiter_config
- self.assertIsInstance(config["waiters"], dict)
+ assert isinstance(config["waiters"], dict)
waiters = list(sorted(config["waiters"].keys()))
- self.assertEqual(waiters, ["JobComplete", "JobExists", "JobRunning"])
- self.assertEqual(waiters, self.batch_waiters.list_waiters())
+ assert waiters == ["JobComplete", "JobExists", "JobRunning"]
+ assert waiters == self.batch_waiters.list_waiters()
def test_waiter_model(self):
model = self.batch_waiters.waiter_model
- self.assertIsInstance(model, botocore.waiter.WaiterModel)
+ assert isinstance(model, botocore.waiter.WaiterModel)
# test some of the default config
- self.assertEqual(model.version, 2)
+ assert model.version == 2
waiters = sorted(model.waiter_names)
- self.assertEqual(waiters, ["JobComplete", "JobExists", "JobRunning"])
+ assert waiters == ["JobComplete", "JobExists", "JobRunning"]
# test errors when requesting a waiter with the wrong name
- with self.assertRaises(ValueError) as e:
+ with pytest.raises(ValueError) as ctx:
model.get_waiter("JobExist")
- self.assertIn("Waiter does not exist: JobExist", str(e.exception))
+ assert "Waiter does not exist: JobExist" in str(ctx.value)
# test some default waiter properties
waiter = model.get_waiter("JobExists")
- self.assertIsInstance(waiter, botocore.waiter.SingleWaiterConfig)
- self.assertEqual(waiter.max_attempts, 100)
+ assert isinstance(waiter, botocore.waiter.SingleWaiterConfig)
+ assert waiter.max_attempts == 100
waiter.max_attempts = 200
- self.assertEqual(waiter.max_attempts, 200)
- self.assertEqual(waiter.delay, 2)
+ assert waiter.max_attempts == 200
+ assert waiter.delay == 2
waiter.delay = 10
- self.assertEqual(waiter.delay, 10)
- self.assertEqual(waiter.operation, "DescribeJobs")
+ assert waiter.delay == 10
+ assert waiter.operation == "DescribeJobs"
def test_wait_for_job(self):
import sys
@@ -403,18 +400,19 @@ def test_wait_for_job(self):
self.batch_waiters.wait_for_job(self.job_id)
- self.assertEqual(
- get_waiter.call_args_list,
- [mock.call("JobExists"), mock.call("JobRunning"), mock.call("JobComplete")],
- )
+ assert get_waiter.call_args_list == [
+ mock.call("JobExists"),
+ mock.call("JobRunning"),
+ mock.call("JobComplete"),
+ ]
mock_waiter = get_waiter.return_value
mock_waiter.wait.assert_called_with(jobs=[self.job_id])
- self.assertEqual(mock_waiter.wait.call_count, 3)
+ assert mock_waiter.wait.call_count == 3
mock_config = mock_waiter.config
- self.assertEqual(mock_config.delay, 0)
- self.assertEqual(mock_config.max_attempts, sys.maxsize)
+ assert mock_config.delay == 0
+ assert mock_config.max_attempts == sys.maxsize
def test_wait_for_job_raises_for_client_error(self):
# mock delay for speedy test
@@ -427,12 +425,12 @@ def test_wait_for_job_raises_for_client_error(self):
error_response={"Error": {"Code": "TooManyRequestsException"}},
operation_name="get job description",
)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.batch_waiters.wait_for_job(self.job_id)
- self.assertEqual(get_waiter.call_args_list, [mock.call("JobExists")])
+ assert get_waiter.call_args_list == [mock.call("JobExists")]
mock_waiter.wait.assert_called_with(jobs=[self.job_id])
- self.assertEqual(mock_waiter.wait.call_count, 1)
+ assert mock_waiter.wait.call_count == 1
def test_wait_for_job_raises_for_waiter_error(self):
# mock delay for speedy test
@@ -444,9 +442,9 @@ def test_wait_for_job_raises_for_waiter_error(self):
mock_waiter.wait.side_effect = botocore.exceptions.WaiterError(
name="JobExists", reason="unit test error", last_response={}
)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.batch_waiters.wait_for_job(self.job_id)
- self.assertEqual(get_waiter.call_args_list, [mock.call("JobExists")])
+ assert get_waiter.call_args_list == [mock.call("JobExists")]
mock_waiter.wait.assert_called_with(jobs=[self.job_id])
- self.assertEqual(mock_waiter.wait.call_count, 1)
+ assert mock_waiter.wait.call_count == 1
diff --git a/tests/providers/amazon/aws/hooks/test_cloud_formation.py b/tests/providers/amazon/aws/hooks/test_cloud_formation.py
index f7d77d08a115b..09e0bb8cd9b6c 100644
--- a/tests/providers/amazon/aws/hooks/test_cloud_formation.py
+++ b/tests/providers/amazon/aws/hooks/test_cloud_formation.py
@@ -57,18 +57,18 @@ def create_stack(self, stack_name):
@mock_cloudformation
def test_get_conn_returns_a_boto3_connection(self):
- self.assertIsNotNone(self.hook.get_conn().describe_stacks())
+ assert self.hook.get_conn().describe_stacks() is not None
@mock_cloudformation
def test_get_stack_status(self):
stack_name = 'my_test_get_stack_status_stack'
stack_status = self.hook.get_stack_status(stack_name=stack_name)
- self.assertIsNone(stack_status)
+ assert stack_status is None
self.create_stack(stack_name)
stack_status = self.hook.get_stack_status(stack_name=stack_name)
- self.assertEqual(stack_status, 'CREATE_COMPLETE', 'Incorrect stack status returned.')
+ assert stack_status == 'CREATE_COMPLETE', 'Incorrect stack status returned.'
@mock_cloudformation
def test_create_stack(self):
@@ -76,13 +76,13 @@ def test_create_stack(self):
self.create_stack(stack_name)
stacks = self.hook.get_conn().describe_stacks()['Stacks']
- self.assertGreater(len(stacks), 0, 'CloudFormation should have stacks')
+ assert len(stacks) > 0, 'CloudFormation should have stacks'
matching_stacks = [x for x in stacks if x['StackName'] == stack_name]
- self.assertEqual(len(matching_stacks), 1, f'stack with name {stack_name} should exist')
+ assert len(matching_stacks) == 1, f'stack with name {stack_name} should exist'
stack = matching_stacks[0]
- self.assertEqual(stack['StackStatus'], 'CREATE_COMPLETE', 'Stack should be in status CREATE_COMPLETE')
+ assert stack['StackStatus'] == 'CREATE_COMPLETE', 'Stack should be in status CREATE_COMPLETE'
@mock_cloudformation
def test_delete_stack(self):
@@ -93,4 +93,4 @@ def test_delete_stack(self):
stacks = self.hook.get_conn().describe_stacks()['Stacks']
matching_stacks = [x for x in stacks if x['StackName'] == stack_name]
- self.assertEqual(len(matching_stacks), 0, f'stack with name {stack_name} should not exist')
+ assert len(matching_stacks) == 0, f'stack with name {stack_name} should not exist'
diff --git a/tests/providers/amazon/aws/hooks/test_datasync.py b/tests/providers/amazon/aws/hooks/test_datasync.py
index 94d420fc957c9..eccdd5fa61ccc 100644
--- a/tests/providers/amazon/aws/hooks/test_datasync.py
+++ b/tests/providers/amazon/aws/hooks/test_datasync.py
@@ -20,6 +20,7 @@
from unittest import mock
import boto3
+import pytest
from moto import mock_datasync
from airflow.exceptions import AirflowTaskTimeout
@@ -30,7 +31,7 @@
class TestAwsDataSyncHook(unittest.TestCase):
def test_get_conn(self):
hook = AWSDataSyncHook(aws_conn_id="aws_default")
- self.assertIsNotNone(hook.get_conn())
+ assert hook.get_conn() is not None
# Explanation of: @mock.patch.object(AWSDataSyncHook, 'get_conn')
@@ -100,9 +101,9 @@ def test_init(self, mock_get_conn):
mock_get_conn.return_value = self.client
# ### Begin tests:
- self.assertFalse(self.hook.locations)
- self.assertFalse(self.hook.tasks)
- self.assertEqual(self.hook.wait_interval_seconds, 0)
+ assert not self.hook.locations
+ assert not self.hook.tasks
+ assert self.hook.wait_interval_seconds == 0
def test_create_location_smb(self, mock_get_conn):
# ### Configure mock:
@@ -110,7 +111,7 @@ def test_create_location_smb(self, mock_get_conn):
# ### Begin tests:
locations = self.hook.get_conn().list_locations()
- self.assertEqual(len(locations["Locations"]), 2)
+ assert len(locations["Locations"]) == 2
server_hostname = "my.hostname"
subdirectory = "my_dir"
@@ -131,18 +132,18 @@ def test_create_location_smb(self, mock_get_conn):
"MountOptions": mount_options,
}
location_arn = self.hook.create_location(location_uri, **create_location_kwargs)
- self.assertIsNotNone(location_arn)
+ assert location_arn is not None
locations = self.client.list_locations()
- self.assertEqual(len(locations["Locations"]), 3)
+ assert len(locations["Locations"]) == 3
location_desc = self.client.describe_location_smb(LocationArn=location_arn)
- self.assertEqual(location_desc["LocationArn"], location_arn)
- self.assertEqual(location_desc["LocationUri"], location_uri)
- self.assertEqual(location_desc["AgentArns"], agent_arns)
- self.assertEqual(location_desc["User"], user)
- self.assertEqual(location_desc["Domain"], domain)
- self.assertEqual(location_desc["MountOptions"], mount_options)
+ assert location_desc["LocationArn"] == location_arn
+ assert location_desc["LocationUri"] == location_uri
+ assert location_desc["AgentArns"] == agent_arns
+ assert location_desc["User"] == user
+ assert location_desc["Domain"] == domain
+ assert location_desc["MountOptions"] == mount_options
def test_create_location_s3(self, mock_get_conn):
# ### Configure mock:
@@ -150,7 +151,7 @@ def test_create_location_s3(self, mock_get_conn):
# ### Begin tests:
locations = self.hook.get_conn().list_locations()
- self.assertEqual(len(locations["Locations"]), 2)
+ assert len(locations["Locations"]) == 2
s3_bucket_arn = "some_s3_arn"
subdirectory = "my_subdir"
@@ -164,15 +165,15 @@ def test_create_location_s3(self, mock_get_conn):
"S3Config": s3_config,
}
location_arn = self.hook.create_location(location_uri, **create_location_kwargs)
- self.assertIsNotNone(location_arn)
+ assert location_arn is not None
locations = self.client.list_locations()
- self.assertEqual(len(locations["Locations"]), 3)
+ assert len(locations["Locations"]) == 3
location_desc = self.client.describe_location_s3(LocationArn=location_arn)
- self.assertEqual(location_desc["LocationArn"], location_arn)
- self.assertEqual(location_desc["LocationUri"], location_uri)
- self.assertEqual(location_desc["S3Config"], s3_config)
+ assert location_desc["LocationArn"] == location_arn
+ assert location_desc["LocationUri"] == location_uri
+ assert location_desc["S3Config"] == s3_config
def test_create_task(self, mock_get_conn):
# ### Configure mock:
@@ -207,10 +208,10 @@ def test_create_task(self, mock_get_conn):
)
task = self.client.describe_task(TaskArn=task_arn)
- self.assertEqual(task["TaskArn"], task_arn)
- self.assertEqual(task["Name"], name)
- self.assertEqual(task["CloudWatchLogGroupArn"], log_group_arn)
- self.assertEqual(task["Options"], options)
+ assert task["TaskArn"] == task_arn
+ assert task["Name"] == name
+ assert task["CloudWatchLogGroupArn"] == log_group_arn
+ assert task["Options"] == options
def test_update_task(self, mock_get_conn):
# ### Configure mock:
@@ -220,13 +221,13 @@ def test_update_task(self, mock_get_conn):
task_arn = self.task_arn
task = self.client.describe_task(TaskArn=task_arn)
- self.assertNotIn("Name", task)
+ assert "Name" not in task
update_task_kwargs = {"Name": "xyz"}
self.hook.update_task(task_arn, **update_task_kwargs)
task = self.client.describe_task(TaskArn=task_arn)
- self.assertEqual(task["Name"], "xyz")
+ assert task["Name"] == "xyz"
def test_delete_task(self, mock_get_conn):
# ### Configure mock:
@@ -236,12 +237,12 @@ def test_delete_task(self, mock_get_conn):
task_arn = self.task_arn
tasks = self.client.list_tasks()
- self.assertEqual(len(tasks["Tasks"]), 1)
+ assert len(tasks["Tasks"]) == 1
self.hook.delete_task(task_arn)
tasks = self.client.list_tasks()
- self.assertEqual(len(tasks["Tasks"]), 0)
+ assert len(tasks["Tasks"]) == 0
def test_get_location_arns(self, mock_get_conn):
# ### Configure mock:
@@ -258,8 +259,8 @@ def test_get_location_arns(self, mock_get_conn):
# Verify our self.hook gets the same
location_arns = self.hook.get_location_arns(location_uri)
- self.assertEqual(len(location_arns), 1)
- self.assertEqual(location_arns[0], location_arn)
+ assert len(location_arns) == 1
+ assert location_arns[0] == location_arn
def test_get_location_arns_case_sensitive(self, mock_get_conn):
# ### Configure mock:
@@ -275,10 +276,10 @@ def test_get_location_arns_case_sensitive(self, mock_get_conn):
# Verify our self.hook can do case sensitive searches
location_arns = self.hook.get_location_arns(location_uri, case_sensitive=True)
- self.assertEqual(len(location_arns), 0)
+ assert len(location_arns) == 0
location_arns = self.hook.get_location_arns(location_uri, case_sensitive=False)
- self.assertEqual(len(location_arns), 1)
- self.assertEqual(location_arns[0], location_arn)
+ assert len(location_arns) == 1
+ assert location_arns[0] == location_arn
def test_get_location_arns_trailing_slash(self, mock_get_conn):
# ### Configure mock:
@@ -294,10 +295,10 @@ def test_get_location_arns_trailing_slash(self, mock_get_conn):
# Verify our self.hook manages trailing / correctly
location_arns = self.hook.get_location_arns(location_uri, ignore_trailing_slash=False)
- self.assertEqual(len(location_arns), 0)
+ assert len(location_arns) == 0
location_arns = self.hook.get_location_arns(location_uri, ignore_trailing_slash=True)
- self.assertEqual(len(location_arns), 1)
- self.assertEqual(location_arns[0], location_arn)
+ assert len(location_arns) == 1
+ assert location_arns[0] == location_arn
def test_get_task_arns_for_location_arns(self, mock_get_conn):
# ### Configure mock:
@@ -307,11 +308,11 @@ def test_get_task_arns_for_location_arns(self, mock_get_conn):
task_arns = self.hook.get_task_arns_for_location_arns(
[self.source_location_arn], [self.destination_location_arn]
)
- self.assertEqual(len(task_arns), 1)
- self.assertEqual(task_arns[0], self.task_arn)
+ assert len(task_arns) == 1
+ assert task_arns[0] == self.task_arn
task_arns = self.hook.get_task_arns_for_location_arns(["foo"], ["bar"])
- self.assertEqual(len(task_arns), 0)
+ assert len(task_arns) == 0
def test_start_task_execution(self, mock_get_conn):
# ### Configure mock:
@@ -319,17 +320,17 @@ def test_start_task_execution(self, mock_get_conn):
# ### Begin tests:
task = self.client.describe_task(TaskArn=self.task_arn)
- self.assertNotIn("CurrentTaskExecutionArn", task)
+ assert "CurrentTaskExecutionArn" not in task
task_execution_arn = self.hook.start_task_execution(self.task_arn)
- self.assertIsNotNone(task_execution_arn)
+ assert task_execution_arn is not None
task = self.client.describe_task(TaskArn=self.task_arn)
- self.assertIn("CurrentTaskExecutionArn", task)
- self.assertEqual(task["CurrentTaskExecutionArn"], task_execution_arn)
+ assert "CurrentTaskExecutionArn" in task
+ assert task["CurrentTaskExecutionArn"] == task_execution_arn
task_execution = self.client.describe_task_execution(TaskExecutionArn=task_execution_arn)
- self.assertIn("Status", task_execution)
+ assert "Status" in task_execution
def test_cancel_task_execution(self, mock_get_conn):
# ### Configure mock:
@@ -337,12 +338,12 @@ def test_cancel_task_execution(self, mock_get_conn):
# ### Begin tests:
task_execution_arn = self.hook.start_task_execution(self.task_arn)
- self.assertIsNotNone(task_execution_arn)
+ assert task_execution_arn is not None
self.hook.cancel_task_execution(task_execution_arn=task_execution_arn)
task = self.client.describe_task(TaskArn=self.task_arn)
- self.assertNotIn("CurrentTaskExecutionArn", task)
+ assert "CurrentTaskExecutionArn" not in task
def test_get_task_description(self, mock_get_conn):
# ### Configure mock:
@@ -350,11 +351,11 @@ def test_get_task_description(self, mock_get_conn):
# ### Begin tests:
task = self.client.describe_task(TaskArn=self.task_arn)
- self.assertIn("TaskArn", task)
- self.assertIn("Status", task)
- self.assertIn("SourceLocationArn", task)
- self.assertIn("DestinationLocationArn", task)
- self.assertNotIn("CurrentTaskExecutionArn", task)
+ assert "TaskArn" in task
+ assert "Status" in task
+ assert "SourceLocationArn" in task
+ assert "DestinationLocationArn" in task
+ assert "CurrentTaskExecutionArn" not in task
def test_get_current_task_execution_arn(self, mock_get_conn):
# ### Configure mock:
@@ -364,7 +365,7 @@ def test_get_current_task_execution_arn(self, mock_get_conn):
task_execution_arn = self.hook.start_task_execution(self.task_arn)
current_task_execution = self.hook.get_current_task_execution_arn(self.task_arn)
- self.assertEqual(current_task_execution, task_execution_arn)
+ assert current_task_execution == task_execution_arn
def test_wait_for_task_execution(self, mock_get_conn):
# ### Configure mock:
@@ -374,7 +375,7 @@ def test_wait_for_task_execution(self, mock_get_conn):
task_execution_arn = self.hook.start_task_execution(self.task_arn)
result = self.hook.wait_for_task_execution(task_execution_arn, max_iterations=20)
- self.assertIsNotNone(result)
+ assert result is not None
def test_wait_for_task_execution_timeout(self, mock_get_conn):
# ### Configure mock:
@@ -382,6 +383,6 @@ def test_wait_for_task_execution_timeout(self, mock_get_conn):
# ### Begin tests:
task_execution_arn = self.hook.start_task_execution(self.task_arn)
- with self.assertRaises(AirflowTaskTimeout):
+ with pytest.raises(AirflowTaskTimeout):
result = self.hook.wait_for_task_execution(task_execution_arn, max_iterations=1)
- self.assertIsNone(result)
+ assert result is None
diff --git a/tests/providers/amazon/aws/hooks/test_dynamodb.py b/tests/providers/amazon/aws/hooks/test_dynamodb.py
index 7377c28024bd5..a46338ff9670f 100644
--- a/tests/providers/amazon/aws/hooks/test_dynamodb.py
+++ b/tests/providers/amazon/aws/hooks/test_dynamodb.py
@@ -33,7 +33,7 @@ class TestDynamoDBHook(unittest.TestCase):
@mock_dynamodb2
def test_get_conn_returns_a_boto3_connection(self):
hook = AwsDynamoDBHook(aws_conn_id='aws_default')
- self.assertIsNotNone(hook.get_conn())
+ assert hook.get_conn() is not None
@unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamodb2 package not present')
@mock_dynamodb2
@@ -60,4 +60,4 @@ def test_insert_batch_items_dynamodb_table(self):
hook.write_batch_data(items)
table.meta.client.get_waiter('table_exists').wait(TableName='test_airflow')
- self.assertEqual(table.item_count, 10)
+ assert table.item_count == 10
diff --git a/tests/providers/amazon/aws/hooks/test_ec2.py b/tests/providers/amazon/aws/hooks/test_ec2.py
index 3e89860142864..fd390b62b8638 100644
--- a/tests/providers/amazon/aws/hooks/test_ec2.py
+++ b/tests/providers/amazon/aws/hooks/test_ec2.py
@@ -30,14 +30,14 @@ def test_init(self):
aws_conn_id="aws_conn_test",
region_name="region-test",
)
- self.assertEqual(ec2_hook.aws_conn_id, "aws_conn_test")
- self.assertEqual(ec2_hook.region_name, "region-test")
+ assert ec2_hook.aws_conn_id == "aws_conn_test"
+ assert ec2_hook.region_name == "region-test"
@mock_ec2
def test_get_conn_returns_boto3_resource(self):
ec2_hook = EC2Hook()
instances = list(ec2_hook.conn.instances.all())
- self.assertIsNotNone(instances)
+ assert instances is not None
@mock_ec2
def test_get_instance(self):
@@ -49,7 +49,7 @@ def test_get_instance(self):
created_instance_id = created_instances[0].instance_id
# test get_instance method
existing_instance = ec2_hook.get_instance(instance_id=created_instance_id)
- self.assertEqual(created_instance_id, existing_instance.instance_id)
+ assert created_instance_id == existing_instance.instance_id
@mock_ec2
def test_get_instance_state(self):
@@ -63,4 +63,4 @@ def test_get_instance_state(self):
created_instance_state = all_instances[0].state["Name"]
# test get_instance_state method
existing_instance_state = ec2_hook.get_instance_state(instance_id=created_instance_id)
- self.assertEqual(created_instance_state, existing_instance_state)
+ assert created_instance_state == existing_instance_state
diff --git a/tests/providers/amazon/aws/hooks/test_elasticache_replication_group.py b/tests/providers/amazon/aws/hooks/test_elasticache_replication_group.py
index 4e40541498b16..d45d1a8f8d1a2 100644
--- a/tests/providers/amazon/aws/hooks/test_elasticache_replication_group.py
+++ b/tests/providers/amazon/aws/hooks/test_elasticache_replication_group.py
@@ -19,6 +19,8 @@
from unittest import TestCase
from unittest.mock import Mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.elasticache_replication_group import ElastiCacheReplicationGroupHook
@@ -202,7 +204,7 @@ def test_ensure_delete_replication_group_failure(self):
"ReplicationGroup": {"ReplicationGroupId": self.REPLICATION_GROUP_ID}
}
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
# Try only 1 once with 1 sec buffer time. This will ensure that the `wait_for_deletion` loop
# breaks quickly before the group is deleted and we get the Airflow exception
self.hook.ensure_delete_replication_group(
diff --git a/tests/providers/amazon/aws/hooks/test_emr.py b/tests/providers/amazon/aws/hooks/test_emr.py
index 7d954e88ee0f3..89d4f5023fb43 100644
--- a/tests/providers/amazon/aws/hooks/test_emr.py
+++ b/tests/providers/amazon/aws/hooks/test_emr.py
@@ -34,7 +34,7 @@ class TestEmrHook(unittest.TestCase):
@mock_emr
def test_get_conn_returns_a_boto3_connection(self):
hook = EmrHook(aws_conn_id='aws_default', region_name='ap-southeast-2')
- self.assertIsNotNone(hook.get_conn().list_clusters())
+ assert hook.get_conn().list_clusters() is not None
@mock_emr
def test_create_job_flow_uses_the_emr_config_to_create_a_cluster(self):
@@ -43,7 +43,7 @@ def test_create_job_flow_uses_the_emr_config_to_create_a_cluster(self):
hook = EmrHook(aws_conn_id='aws_default', emr_conn_id='emr_default')
cluster = hook.create_job_flow({'Name': 'test_cluster'})
- self.assertEqual(client.list_clusters()['Clusters'][0]['Id'], cluster['JobFlowId'])
+ assert client.list_clusters()['Clusters'][0]['Id'] == cluster['JobFlowId']
@mock_emr
def test_create_job_flow_extra_args(self):
@@ -64,7 +64,7 @@ def test_create_job_flow_extra_args(self):
cluster = client.describe_cluster(ClusterId=cluster['JobFlowId'])['Cluster']
# The AmiVersion comes back as {Requested,Running}AmiVersion fields.
- self.assertEqual(cluster['RequestedAmiVersion'], '3.2')
+ assert cluster['RequestedAmiVersion'] == '3.2'
@mock_emr
def test_get_cluster_id_by_name(self):
@@ -81,8 +81,8 @@ def test_get_cluster_id_by_name(self):
matching_cluster = hook.get_cluster_id_by_name('test_cluster', ['RUNNING', 'WAITING'])
- self.assertEqual(matching_cluster, job_flow_id)
+ assert matching_cluster == job_flow_id
no_match = hook.get_cluster_id_by_name('foo', ['RUNNING', 'WAITING', 'BOOTSTRAPPING'])
- self.assertIsNone(no_match)
+ assert no_match is None
diff --git a/tests/providers/amazon/aws/hooks/test_glacier.py b/tests/providers/amazon/aws/hooks/test_glacier.py
index 0f60f371884f1..c1c86a5d9a7a0 100644
--- a/tests/providers/amazon/aws/hooks/test_glacier.py
+++ b/tests/providers/amazon/aws/hooks/test_glacier.py
@@ -45,7 +45,7 @@ def test_retrieve_inventory_should_return_job_id(self, mock_conn):
result = self.hook.retrieve_inventory(VAULT_NAME)
# then
mock_conn.assert_called_once_with()
- self.assertEqual(job_id, result)
+ assert job_id == result
@mock.patch("airflow.providers.amazon.aws.hooks.glacier.GlacierHook.get_conn")
def test_retrieve_inventory_should_log_mgs(self, mock_conn):
@@ -81,7 +81,7 @@ def test_retrieve_inventory_results_should_return_response(self, mock_conn):
response = self.hook.retrieve_inventory_results(VAULT_NAME, JOB_ID)
# then
mock_conn.assert_called_once_with()
- self.assertEqual(response, RESPONSE_BODY)
+ assert response == RESPONSE_BODY
@mock.patch("airflow.providers.amazon.aws.hooks.glacier.GlacierHook.get_conn")
def test_retrieve_inventory_results_should_log_mgs(self, mock_conn):
@@ -105,7 +105,7 @@ def test_describe_job_should_return_status_succeeded(self, mock_conn):
response = self.hook.describe_job(VAULT_NAME, JOB_ID)
# then
mock_conn.assert_called_once_with()
- self.assertEqual(response, JOB_STATUS)
+ assert response == JOB_STATUS
@mock.patch("airflow.providers.amazon.aws.hooks.glacier.GlacierHook.get_conn")
def test_describe_job_should_log_mgs(self, mock_conn):
diff --git a/tests/providers/amazon/aws/hooks/test_glue.py b/tests/providers/amazon/aws/hooks/test_glue.py
index 512e7c84691cd..0bc57527510b0 100644
--- a/tests/providers/amazon/aws/hooks/test_glue.py
+++ b/tests/providers/amazon/aws/hooks/test_glue.py
@@ -52,7 +52,7 @@ def test_get_iam_execution_role(self):
),
)
iam_role = hook.get_iam_execution_role()
- self.assertIsNotNone(iam_role)
+ assert iam_role is not None
@mock.patch.object(AwsGlueJobHook, "get_iam_execution_role")
@mock.patch.object(AwsGlueJobHook, "get_conn")
@@ -70,7 +70,7 @@ def test_get_or_create_glue_job(self, mock_get_conn, mock_get_iam_execution_role
s3_bucket=some_s3_bucket,
region_name=self.some_aws_region,
).get_or_create_glue_job()
- self.assertEqual(glue_job, mock_glue_job)
+ assert glue_job == mock_glue_job
@mock.patch.object(AwsGlueJobHook, "get_job_state")
@mock.patch.object(AwsGlueJobHook, "get_or_create_glue_job")
@@ -95,7 +95,7 @@ def test_initialize_job(self, mock_get_conn, mock_get_or_create_glue_job, mock_g
)
glue_job_run = glue_job_hook.initialize_job(some_script_arguments)
glue_job_run_state = glue_job_hook.get_job_state(glue_job_run['JobName'], glue_job_run['JobRunId'])
- self.assertEqual(glue_job_run_state, mock_job_run_state, msg='Mocks but be equal')
+ assert glue_job_run_state == mock_job_run_state, 'Mocks but be equal'
if __name__ == '__main__':
diff --git a/tests/providers/amazon/aws/hooks/test_glue_catalog.py b/tests/providers/amazon/aws/hooks/test_glue_catalog.py
index 3c2adc59464ba..689df491d7af9 100644
--- a/tests/providers/amazon/aws/hooks/test_glue_catalog.py
+++ b/tests/providers/amazon/aws/hooks/test_glue_catalog.py
@@ -20,6 +20,7 @@
from unittest import mock
import boto3
+import pytest
from airflow.providers.amazon.aws.hooks.glue_catalog import AwsGlueCatalogHook
@@ -49,17 +50,17 @@ def setUp(self):
@mock_glue
def test_get_conn_returns_a_boto3_connection(self):
hook = AwsGlueCatalogHook(region_name="us-east-1")
- self.assertIsNotNone(hook.get_conn())
+ assert hook.get_conn() is not None
@mock_glue
def test_conn_id(self):
hook = AwsGlueCatalogHook(aws_conn_id='my_aws_conn_id', region_name="us-east-1")
- self.assertEqual(hook.aws_conn_id, 'my_aws_conn_id')
+ assert hook.aws_conn_id == 'my_aws_conn_id'
@mock_glue
def test_region(self):
hook = AwsGlueCatalogHook(region_name="us-west-2")
- self.assertEqual(hook.region_name, 'us-west-2')
+ assert hook.region_name == 'us-west-2'
@mock_glue
@mock.patch.object(AwsGlueCatalogHook, 'get_conn')
@@ -68,7 +69,7 @@ def test_get_partitions_empty(self, mock_get_conn):
mock_get_conn.get_paginator.paginate.return_value = response
hook = AwsGlueCatalogHook(region_name="us-east-1")
- self.assertEqual(hook.get_partitions('db', 'tbl'), set())
+ assert hook.get_partitions('db', 'tbl') == set()
@mock_glue
@mock.patch.object(AwsGlueCatalogHook, 'get_conn')
@@ -82,7 +83,7 @@ def test_get_partitions(self, mock_get_conn):
hook = AwsGlueCatalogHook(region_name="us-east-1")
result = hook.get_partitions('db', 'tbl', expression='foo=bar', page_size=2, max_items=3)
- self.assertEqual(result, {('2015-01-01',)})
+ assert result == {('2015-01-01',)}
mock_conn.get_paginator.assert_called_once_with('get_partitions')
mock_paginator.paginate.assert_called_once_with(
DatabaseName='db',
@@ -97,7 +98,7 @@ def test_check_for_partition(self, mock_get_partitions):
mock_get_partitions.return_value = {('2018-01-01',)}
hook = AwsGlueCatalogHook(region_name="us-east-1")
- self.assertTrue(hook.check_for_partition('db', 'tbl', 'expr'))
+ assert hook.check_for_partition('db', 'tbl', 'expr')
mock_get_partitions.assert_called_once_with('db', 'tbl', 'expr', max_items=1)
@mock_glue
@@ -106,7 +107,7 @@ def test_check_for_partition_false(self, mock_get_partitions):
mock_get_partitions.return_value = set()
hook = AwsGlueCatalogHook(region_name="us-east-1")
- self.assertFalse(hook.check_for_partition('db', 'tbl', 'expr'))
+ assert not hook.check_for_partition('db', 'tbl', 'expr')
@mock_glue
def test_get_table_exists(self):
@@ -115,17 +116,15 @@ def test_get_table_exists(self):
result = self.hook.get_table(DB_NAME, TABLE_NAME)
- self.assertEqual(result['Name'], TABLE_INPUT['Name'])
- self.assertEqual(
- result['StorageDescriptor']['Location'], TABLE_INPUT['StorageDescriptor']['Location']
- )
+ assert result['Name'] == TABLE_INPUT['Name']
+ assert result['StorageDescriptor']['Location'] == TABLE_INPUT['StorageDescriptor']['Location']
@mock_glue
def test_get_table_not_exists(self):
self.client.create_database(DatabaseInput={'Name': DB_NAME})
self.client.create_table(DatabaseName=DB_NAME, TableInput=TABLE_INPUT)
- with self.assertRaises(Exception):
+ with pytest.raises(Exception):
self.hook.get_table(DB_NAME, 'dummy_table')
@mock_glue
@@ -134,4 +133,4 @@ def test_get_table_location(self):
self.client.create_table(DatabaseName=DB_NAME, TableInput=TABLE_INPUT)
result = self.hook.get_table_location(DB_NAME, TABLE_NAME)
- self.assertEqual(result, TABLE_INPUT['StorageDescriptor']['Location'])
+ assert result == TABLE_INPUT['StorageDescriptor']['Location']
diff --git a/tests/providers/amazon/aws/hooks/test_kinesis.py b/tests/providers/amazon/aws/hooks/test_kinesis.py
index 7f2e65dae89b1..52984325ea49c 100644
--- a/tests/providers/amazon/aws/hooks/test_kinesis.py
+++ b/tests/providers/amazon/aws/hooks/test_kinesis.py
@@ -34,7 +34,7 @@ def test_get_conn_returns_a_boto3_connection(self):
hook = AwsFirehoseHook(
aws_conn_id='aws_default', delivery_stream="test_airflow", region_name="us-east-1"
)
- self.assertIsNotNone(hook.get_conn())
+ assert hook.get_conn() is not None
@unittest.skipIf(mock_kinesis is None, 'mock_kinesis package not present')
@mock_kinesis
@@ -55,11 +55,11 @@ def test_insert_batch_records_kinesis_firehose(self):
)
stream_arn = response['DeliveryStreamARN']
- self.assertEqual(stream_arn, "arn:aws:firehose:us-east-1:123456789012:deliverystream/test_airflow")
+ assert stream_arn == "arn:aws:firehose:us-east-1:123456789012:deliverystream/test_airflow"
records = [{"Data": str(uuid.uuid4())} for _ in range(100)]
response = hook.put_records(records)
- self.assertEqual(response['FailedPutCount'], 0)
- self.assertEqual(response['ResponseMetadata']['HTTPStatusCode'], 200)
+ assert response['FailedPutCount'] == 0
+ assert response['ResponseMetadata']['HTTPStatusCode'] == 200
diff --git a/tests/providers/amazon/aws/hooks/test_logs.py b/tests/providers/amazon/aws/hooks/test_logs.py
index 65852aa458f2f..48a78edcebc97 100644
--- a/tests/providers/amazon/aws/hooks/test_logs.py
+++ b/tests/providers/amazon/aws/hooks/test_logs.py
@@ -32,7 +32,7 @@ class TestAwsLogsHook(unittest.TestCase):
@mock_logs
def test_get_conn_returns_a_boto3_connection(self):
hook = AwsLogsHook(aws_conn_id='aws_default', region_name="us-east-1")
- self.assertIsNotNone(hook.get_conn())
+ assert hook.get_conn() is not None
@unittest.skipIf(mock_logs is None, 'mock_logs package not present')
# moto.logs does not support proper pagination so we cannot test that yet
diff --git a/tests/providers/amazon/aws/hooks/test_redshift.py b/tests/providers/amazon/aws/hooks/test_redshift.py
index 551a7c3b6daaf..3a3f94b4e8205 100644
--- a/tests/providers/amazon/aws/hooks/test_redshift.py
+++ b/tests/providers/amazon/aws/hooks/test_redshift.py
@@ -57,7 +57,7 @@ def test_get_client_type_returns_a_boto3_client_of_the_requested_type(self):
client_from_hook = hook.get_conn()
clusters = client_from_hook.describe_clusters()['Clusters']
- self.assertEqual(len(clusters), 2)
+ assert len(clusters) == 2
@unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
@mock_redshift
@@ -65,9 +65,9 @@ def test_restore_from_cluster_snapshot_returns_dict_with_cluster_data(self):
self._create_clusters()
hook = RedshiftHook(aws_conn_id='aws_default')
hook.create_cluster_snapshot('test_snapshot', 'test_cluster')
- self.assertEqual(
- hook.restore_from_cluster_snapshot('test_cluster_3', 'test_snapshot')['ClusterIdentifier'],
- 'test_cluster_3',
+ assert (
+ hook.restore_from_cluster_snapshot('test_cluster_3', 'test_snapshot')['ClusterIdentifier']
+ == 'test_cluster_3'
)
@unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
@@ -77,7 +77,7 @@ def test_delete_cluster_returns_a_dict_with_cluster_data(self):
hook = RedshiftHook(aws_conn_id='aws_default')
cluster = hook.delete_cluster('test_cluster_2')
- self.assertNotEqual(cluster, None)
+ assert cluster is not None
@unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
@mock_redshift
@@ -86,7 +86,7 @@ def test_create_cluster_snapshot_returns_snapshot_data(self):
hook = RedshiftHook(aws_conn_id='aws_default')
snapshot = hook.create_cluster_snapshot('test_snapshot_2', 'test_cluster')
- self.assertNotEqual(snapshot, None)
+ assert snapshot is not None
@unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
@mock_redshift
@@ -94,7 +94,7 @@ def test_cluster_status_returns_cluster_not_found(self):
self._create_clusters()
hook = RedshiftHook(aws_conn_id='aws_default')
status = hook.cluster_status('test_cluster_not_here')
- self.assertEqual(status, 'cluster_not_found')
+ assert status == 'cluster_not_found'
@unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
@mock_redshift
@@ -102,4 +102,4 @@ def test_cluster_status_returns_available_cluster(self):
self._create_clusters()
hook = RedshiftHook(aws_conn_id='aws_default')
status = hook.cluster_status('test_cluster')
- self.assertEqual(status, 'available')
+ assert status == 'available'
diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py
index 055104d1e9e56..94805d6c849d6 100644
--- a/tests/providers/amazon/aws/hooks/test_s3.py
+++ b/tests/providers/amazon/aws/hooks/test_s3.py
@@ -336,10 +336,9 @@ def test_delete_bucket_if_bucket_exist(self, s3_bucket):
def test_delete_bucket_if_not_bucket_exist(self, s3_bucket):
# assert if exception is raised if bucket not present
mock_hook = S3Hook()
- with pytest.raises(ClientError) as error:
- # assert error
+ with pytest.raises(ClientError) as ctx:
assert mock_hook.delete_bucket(bucket_name=s3_bucket, force_delete=True)
- assert error.value.response['Error']['Code'] == 'NoSuchBucket'
+ assert ctx.value.response['Error']['Code'] == 'NoSuchBucket'
@mock.patch.object(S3Hook, 'get_connection', return_value=Connection(schema='test_bucket'))
def test_provide_bucket_name(self, mock_get_connection):
@@ -358,11 +357,11 @@ def test_function(self, bucket_name=None):
def test_delete_objects_key_does_not_exist(self, s3_bucket):
hook = S3Hook()
- with pytest.raises(AirflowException) as err:
+ with pytest.raises(AirflowException) as ctx:
hook.delete_objects(bucket=s3_bucket, keys=['key-1'])
- assert isinstance(err.value, AirflowException)
- assert str(err.value) == "Errors when deleting: ['key-1']"
+ assert isinstance(ctx.value, AirflowException)
+ assert str(ctx.value) == "Errors when deleting: ['key-1']"
def test_delete_objects_one_key(self, mocked_s3_res, s3_bucket):
key = 'key-1'
@@ -406,9 +405,9 @@ def test_function_with_test_key(self, test_key, bucket_name=None):
test_bucket_name_with_key = fake_s3_hook.test_function_with_key('s3://foo/bar.csv')
assert ('foo', 'bar.csv') == test_bucket_name_with_key
- with pytest.raises(ValueError) as err:
+ with pytest.raises(ValueError) as ctx:
fake_s3_hook.test_function_with_test_key('s3://foo/bar.csv')
- assert isinstance(err.value, ValueError)
+ assert isinstance(ctx.value, ValueError)
@mock.patch('airflow.providers.amazon.aws.hooks.s3.NamedTemporaryFile')
def test_download_file(self, mock_temp_file):
diff --git a/tests/providers/amazon/aws/hooks/test_sagemaker.py b/tests/providers/amazon/aws/hooks/test_sagemaker.py
index 5167e6a7e6abc..b714ab3784c42 100644
--- a/tests/providers/amazon/aws/hooks/test_sagemaker.py
+++ b/tests/providers/amazon/aws/hooks/test_sagemaker.py
@@ -22,6 +22,7 @@
from datetime import datetime
from unittest import mock
+import pytest
from dateutil.tz import tzlocal
from airflow.exceptions import AirflowException
@@ -211,7 +212,7 @@ def test_multi_stream_iter(self, mock_log_stream):
mock_log_stream.side_effect = [iter([event]), iter([]), None]
hook = SageMakerHook()
event_iter = hook.multi_stream_iter('log', [None, None, None])
- self.assertEqual(next(event_iter), (0, event))
+ assert next(event_iter) == (0, event)
@mock.patch.object(S3Hook, 'create_bucket')
@mock.patch.object(S3Hook, 'load_file')
@@ -219,7 +220,7 @@ def test_configure_s3_resources(self, mock_load_file, mock_create_bucket):
hook = SageMakerHook()
evaluation_result = {'Image': image, 'Role': role}
hook.configure_s3_resources(test_evaluation_config)
- self.assertEqual(test_evaluation_config, evaluation_result)
+ assert test_evaluation_config == evaluation_result
mock_create_bucket.assert_called_once_with(bucket_name=bucket)
mock_load_file.assert_called_once_with(path, key, bucket)
@@ -233,10 +234,12 @@ def test_check_s3_url(self, mock_check_prefix, mock_check_bucket, mock_check_key
mock_check_bucket.side_effect = [False, True, True, True]
mock_check_key.side_effect = [False, True, False]
mock_check_prefix.side_effect = [False, True, True]
- self.assertRaises(AirflowException, hook.check_s3_url, data_url)
- self.assertRaises(AirflowException, hook.check_s3_url, data_url)
- self.assertEqual(hook.check_s3_url(data_url), True)
- self.assertEqual(hook.check_s3_url(data_url), True)
+ with pytest.raises(AirflowException):
+ hook.check_s3_url(data_url)
+ with pytest.raises(AirflowException):
+ hook.check_s3_url(data_url)
+ assert hook.check_s3_url(data_url) is True
+ assert hook.check_s3_url(data_url) is True
@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'check_s3_url')
@@ -262,7 +265,7 @@ def test_check_valid_tuning(self, mock_check_url, mock_client):
@mock.patch.object(SageMakerHook, 'get_client_type')
def test_conn(self, mock_get_client_type):
hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
- self.assertEqual(hook.aws_conn_id, 'sagemaker_test_conn_id')
+ assert hook.aws_conn_id == 'sagemaker_test_conn_id'
@mock.patch.object(SageMakerHook, 'check_training_config')
@mock.patch.object(SageMakerHook, 'get_conn')
@@ -277,7 +280,7 @@ def test_create_training_job(self, mock_client, mock_check_training):
create_training_params, wait_for_completion=False, print_log=False
)
mock_session.create_training_job.assert_called_once_with(**create_training_params)
- self.assertEqual(response, test_arn_return)
+ assert response == test_arn_return
@mock.patch.object(SageMakerHook, 'check_training_config')
@mock.patch.object(SageMakerHook, 'get_conn')
@@ -299,7 +302,7 @@ def test_training_ends_with_wait(self, mock_client, mock_check_training):
hook.create_training_job(
create_training_params, wait_for_completion=True, print_log=False, check_interval=1
)
- self.assertEqual(mock_session.describe_training_job.call_count, 4)
+ assert mock_session.describe_training_job.call_count == 4
@mock.patch.object(SageMakerHook, 'check_training_config')
@mock.patch.object(SageMakerHook, 'get_conn')
@@ -318,15 +321,14 @@ def test_training_throws_error_when_failed_with_wait(self, mock_client, mock_che
mock_session.configure_mock(**attrs)
mock_client.return_value = mock_session
hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id_1')
- self.assertRaises(
- AirflowException,
- hook.create_training_job,
- create_training_params,
- wait_for_completion=True,
- print_log=False,
- check_interval=1,
- )
- self.assertEqual(mock_session.describe_training_job.call_count, 3)
+ with pytest.raises(AirflowException):
+ hook.create_training_job(
+ create_training_params,
+ wait_for_completion=True,
+ print_log=False,
+ check_interval=1,
+ )
+ assert mock_session.describe_training_job.call_count == 3
@mock.patch.object(SageMakerHook, 'check_tuning_config')
@mock.patch.object(SageMakerHook, 'get_conn')
@@ -338,7 +340,7 @@ def test_create_tuning_job(self, mock_client, mock_check_tuning_config):
hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
response = hook.create_tuning_job(create_tuning_params, wait_for_completion=False)
mock_session.create_hyper_parameter_tuning_job.assert_called_once_with(**create_tuning_params)
- self.assertEqual(response, test_arn_return)
+ assert response == test_arn_return
@mock.patch.object(SageMakerHook, 'check_s3_url')
@mock.patch.object(SageMakerHook, 'get_conn')
@@ -351,7 +353,7 @@ def test_create_transform_job(self, mock_client, mock_check_url):
hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
response = hook.create_transform_job(create_transform_params, wait_for_completion=False)
mock_session.create_transform_job.assert_called_once_with(**create_transform_params)
- self.assertEqual(response, test_arn_return)
+ assert response == test_arn_return
@mock.patch.object(SageMakerHook, 'get_conn')
def test_create_model(self, mock_client):
@@ -362,7 +364,7 @@ def test_create_model(self, mock_client):
hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
response = hook.create_model(create_model_params)
mock_session.create_model.assert_called_once_with(**create_model_params)
- self.assertEqual(response, test_arn_return)
+ assert response == test_arn_return
@mock.patch.object(SageMakerHook, 'get_conn')
def test_create_endpoint_config(self, mock_client):
@@ -373,7 +375,7 @@ def test_create_endpoint_config(self, mock_client):
hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
response = hook.create_endpoint_config(create_endpoint_config_params)
mock_session.create_endpoint_config.assert_called_once_with(**create_endpoint_config_params)
- self.assertEqual(response, test_arn_return)
+ assert response == test_arn_return
@mock.patch.object(SageMakerHook, 'get_conn')
def test_create_endpoint(self, mock_client):
@@ -384,7 +386,7 @@ def test_create_endpoint(self, mock_client):
hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
response = hook.create_endpoint(create_endpoint_params, wait_for_completion=False)
mock_session.create_endpoint.assert_called_once_with(**create_endpoint_params)
- self.assertEqual(response, test_arn_return)
+ assert response == test_arn_return
@mock.patch.object(SageMakerHook, 'get_conn')
def test_update_endpoint(self, mock_client):
@@ -395,7 +397,7 @@ def test_update_endpoint(self, mock_client):
hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
response = hook.update_endpoint(update_endpoint_params, wait_for_completion=False)
mock_session.update_endpoint.assert_called_once_with(**update_endpoint_params)
- self.assertEqual(response, test_arn_return)
+ assert response == test_arn_return
@mock.patch.object(SageMakerHook, 'get_conn')
def test_describe_training_job(self, mock_client):
@@ -406,7 +408,7 @@ def test_describe_training_job(self, mock_client):
hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
response = hook.describe_training_job(job_name)
mock_session.describe_training_job.assert_called_once_with(TrainingJobName=job_name)
- self.assertEqual(response, 'InProgress')
+ assert response == 'InProgress'
@mock.patch.object(SageMakerHook, 'get_conn')
def test_describe_tuning_job(self, mock_client):
@@ -419,7 +421,7 @@ def test_describe_tuning_job(self, mock_client):
mock_session.describe_hyper_parameter_tuning_job.assert_called_once_with(
HyperParameterTuningJobName=job_name
)
- self.assertEqual(response, 'InProgress')
+ assert response == 'InProgress'
@mock.patch.object(SageMakerHook, 'get_conn')
def test_describe_transform_job(self, mock_client):
@@ -430,7 +432,7 @@ def test_describe_transform_job(self, mock_client):
hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
response = hook.describe_transform_job(job_name)
mock_session.describe_transform_job.assert_called_once_with(TransformJobName=job_name)
- self.assertEqual(response, 'InProgress')
+ assert response == 'InProgress'
@mock.patch.object(SageMakerHook, 'get_conn')
def test_describe_model(self, mock_client):
@@ -441,7 +443,7 @@ def test_describe_model(self, mock_client):
hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
response = hook.describe_model(model_name)
mock_session.describe_model.assert_called_once_with(ModelName=model_name)
- self.assertEqual(response, model_name)
+ assert response == model_name
@mock.patch.object(SageMakerHook, 'get_conn')
def test_describe_endpoint_config(self, mock_client):
@@ -452,7 +454,7 @@ def test_describe_endpoint_config(self, mock_client):
hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
response = hook.describe_endpoint_config(config_name)
mock_session.describe_endpoint_config.assert_called_once_with(EndpointConfigName=config_name)
- self.assertEqual(response, config_name)
+ assert response == config_name
@mock.patch.object(SageMakerHook, 'get_conn')
def test_describe_endpoint(self, mock_client):
@@ -463,19 +465,19 @@ def test_describe_endpoint(self, mock_client):
hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
response = hook.describe_endpoint(endpoint_name)
mock_session.describe_endpoint.assert_called_once_with(EndpointName=endpoint_name)
- self.assertEqual(response, 'InProgress')
+ assert response == 'InProgress'
def test_secondary_training_status_changed_true(self):
changed = secondary_training_status_changed(
SECONDARY_STATUS_DESCRIPTION_1, SECONDARY_STATUS_DESCRIPTION_2
)
- self.assertTrue(changed)
+ assert changed
def test_secondary_training_status_changed_false(self):
changed = secondary_training_status_changed(
SECONDARY_STATUS_DESCRIPTION_1, SECONDARY_STATUS_DESCRIPTION_1
)
- self.assertFalse(changed)
+ assert not changed
def test_secondary_training_status_message_status_changed(self):
now = datetime.now(tzlocal())
@@ -485,9 +487,9 @@ def test_secondary_training_status_message_status_changed(self):
status,
message,
)
- self.assertEqual(
- secondary_training_status_message(SECONDARY_STATUS_DESCRIPTION_1, SECONDARY_STATUS_DESCRIPTION_2),
- expected,
+ assert (
+ secondary_training_status_message(SECONDARY_STATUS_DESCRIPTION_1, SECONDARY_STATUS_DESCRIPTION_2)
+ == expected
)
@mock.patch.object(AwsLogsHook, 'get_conn')
@@ -516,7 +518,7 @@ def test_describe_training_job_with_logs_in_progress(self, mock_time, mock_clien
last_description={},
last_describe_job_call=0,
)
- self.assertEqual(response, (LogState.JOB_COMPLETE, {}, 50))
+ assert response == (LogState.JOB_COMPLETE, {}, 50)
@mock.patch.object(AwsLogsHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'get_conn')
@@ -542,7 +544,7 @@ def test_describe_training_job_with_logs_job_complete(self, mock_client, mock_lo
last_description={},
last_describe_job_call=0,
)
- self.assertEqual(response, (LogState.COMPLETE, {}, 0))
+ assert response == (LogState.COMPLETE, {}, 0)
@mock.patch.object(AwsLogsHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'get_conn')
@@ -568,7 +570,7 @@ def test_describe_training_job_with_logs_complete(self, mock_client, mock_log_cl
last_description={},
last_describe_job_call=0,
)
- self.assertEqual(response, (LogState.COMPLETE, {}, 0))
+ assert response == (LogState.COMPLETE, {}, 0)
@mock.patch.object(SageMakerHook, 'check_training_config')
@mock.patch.object(AwsLogsHook, 'get_conn')
@@ -599,5 +601,5 @@ def test_training_with_logs(self, mock_describe, mock_client, mock_log_client, m
hook.create_training_job(
create_training_params, wait_for_completion=True, print_log=True, check_interval=1
)
- self.assertEqual(mock_describe.call_count, 3)
- self.assertEqual(mock_session.describe_training_job.call_count, 1)
+ assert mock_describe.call_count == 3
+ assert mock_session.describe_training_job.call_count == 1
diff --git a/tests/providers/amazon/aws/hooks/test_secrets_manager.py b/tests/providers/amazon/aws/hooks/test_secrets_manager.py
index 916a5ea04def0..bfcd847aa7593 100644
--- a/tests/providers/amazon/aws/hooks/test_secrets_manager.py
+++ b/tests/providers/amazon/aws/hooks/test_secrets_manager.py
@@ -34,7 +34,7 @@ class TestSecretsManagerHook(unittest.TestCase):
@mock_secretsmanager
def test_get_conn_returns_a_boto3_connection(self):
hook = SecretsManagerHook(aws_conn_id='aws_default')
- self.assertIsNotNone(hook.get_conn())
+ assert hook.get_conn() is not None
@unittest.skipIf(mock_secretsmanager is None, 'mock_secretsmanager package not present')
@mock_secretsmanager
@@ -51,7 +51,7 @@ def test_get_secret_string(self):
hook.get_conn().put_secret_value(**param)
secret = hook.get_secret(secret_name)
- self.assertEqual(secret, secret_value)
+ assert secret == secret_value
@unittest.skipIf(mock_secretsmanager is None, 'mock_secretsmanager package not present')
@mock_secretsmanager
@@ -68,7 +68,7 @@ def test_get_secret_dict(self):
hook.get_conn().put_secret_value(**param)
secret = hook.get_secret_as_dict(secret_name)
- self.assertEqual(secret, json.loads(secret_value))
+ assert secret == json.loads(secret_value)
@unittest.skipIf(mock_secretsmanager is None, 'mock_secretsmanager package not present')
@mock_secretsmanager
@@ -85,4 +85,4 @@ def test_get_secret_binary(self):
hook.get_conn().put_secret_value(**param)
secret = hook.get_secret(secret_name)
- self.assertEqual(secret, base64.b64decode(secret_value_binary))
+ assert secret == base64.b64decode(secret_value_binary)
diff --git a/tests/providers/amazon/aws/hooks/test_sns.py b/tests/providers/amazon/aws/hooks/test_sns.py
index f14fde6b9ac71..3c1a72fb6374e 100644
--- a/tests/providers/amazon/aws/hooks/test_sns.py
+++ b/tests/providers/amazon/aws/hooks/test_sns.py
@@ -32,7 +32,7 @@ class TestAwsSnsHook(unittest.TestCase):
@mock_sns
def test_get_conn_returns_a_boto3_connection(self):
hook = AwsSnsHook(aws_conn_id='aws_default')
- self.assertIsNotNone(hook.get_conn())
+ assert hook.get_conn() is not None
@mock_sns
def test_publish_to_target_with_subject(self):
diff --git a/tests/providers/amazon/aws/hooks/test_sqs.py b/tests/providers/amazon/aws/hooks/test_sqs.py
index 4b85af9d7cc15..b4e3e8be07d65 100644
--- a/tests/providers/amazon/aws/hooks/test_sqs.py
+++ b/tests/providers/amazon/aws/hooks/test_sqs.py
@@ -32,4 +32,4 @@ class TestAwsSQSHook(unittest.TestCase):
@mock_sqs
def test_get_conn(self):
hook = SQSHook(aws_conn_id='aws_default')
- self.assertIsNotNone(hook.get_conn())
+ assert hook.get_conn() is not None
diff --git a/tests/providers/amazon/aws/hooks/test_step_function.py b/tests/providers/amazon/aws/hooks/test_step_function.py
index 6a0cb40fbade5..e808d5f382567 100644
--- a/tests/providers/amazon/aws/hooks/test_step_function.py
+++ b/tests/providers/amazon/aws/hooks/test_step_function.py
@@ -32,7 +32,7 @@ class TestStepFunctionHook(unittest.TestCase):
@mock_stepfunctions
def test_get_conn_returns_a_boto3_connection(self):
hook = StepFunctionHook(aws_conn_id='aws_default')
- self.assertEqual('stepfunctions', hook.get_conn().meta.service_model.service_name)
+ assert 'stepfunctions' == hook.get_conn().meta.service_model.service_name
@mock_stepfunctions
def test_start_execution(self):
diff --git a/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py b/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py
index 12079bd9b9d31..827e69f766ab8 100644
--- a/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py
+++ b/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py
@@ -74,7 +74,7 @@ def tearDown(self):
self.cloudwatch_task_handler.handler = None
def test_hook(self):
- self.assertIsInstance(self.cloudwatch_task_handler.hook, AwsLogsHook)
+ assert isinstance(self.cloudwatch_task_handler.hook, AwsLogsHook)
@conf_vars({('logging', 'remote_log_conn_id'): 'aws_default'})
def test_hook_raises(self):
@@ -98,7 +98,7 @@ def test_hook_raises(self):
def test_handler(self):
self.cloudwatch_task_handler.set_context(self.ti)
- self.assertIsInstance(self.cloudwatch_task_handler.handler, CloudWatchLogHandler)
+ assert isinstance(self.cloudwatch_task_handler.handler, CloudWatchLogHandler)
def test_write(self):
handler = self.cloudwatch_task_handler
@@ -129,12 +129,9 @@ def test_read(self):
expected = (
'*** Reading remote log from Cloudwatch log_group: {} log_stream: {}.\nFirst\nSecond\nThird\n'
)
- self.assertEqual(
- self.cloudwatch_task_handler.read(self.ti),
- (
- [[('', expected.format(self.remote_log_group, self.remote_log_stream))]],
- [{'end_of_log': True}],
- ),
+ assert self.cloudwatch_task_handler.read(self.ti) == (
+ [[('', expected.format(self.remote_log_group, self.remote_log_stream))]],
+ [{'end_of_log': True}],
)
def test_read_wrong_log_stream(self):
@@ -153,12 +150,9 @@ def test_read_wrong_log_stream(self):
error_msg = 'Could not read remote logs from log_group: {} log_stream: {}.'.format(
self.remote_log_group, self.remote_log_stream
)
- self.assertEqual(
- self.cloudwatch_task_handler.read(self.ti),
- (
- [[('', msg_template.format(self.remote_log_group, self.remote_log_stream, error_msg))]],
- [{'end_of_log': True}],
- ),
+ assert self.cloudwatch_task_handler.read(self.ti) == (
+ [[('', msg_template.format(self.remote_log_group, self.remote_log_stream, error_msg))]],
+ [{'end_of_log': True}],
)
def test_read_wrong_log_group(self):
@@ -177,12 +171,9 @@ def test_read_wrong_log_group(self):
error_msg = 'Could not read remote logs from log_group: {} log_stream: {}.'.format(
self.remote_log_group, self.remote_log_stream
)
- self.assertEqual(
- self.cloudwatch_task_handler.read(self.ti),
- (
- [[('', msg_template.format(self.remote_log_group, self.remote_log_stream, error_msg))]],
- [{'end_of_log': True}],
- ),
+ assert self.cloudwatch_task_handler.read(self.ti) == (
+ [[('', msg_template.format(self.remote_log_group, self.remote_log_stream, error_msg))]],
+ [{'end_of_log': True}],
)
def test_close_prevents_duplicate_calls(self):
diff --git a/tests/providers/amazon/aws/log/test_s3_task_handler.py b/tests/providers/amazon/aws/log/test_s3_task_handler.py
index b312342565288..c35339637993a 100644
--- a/tests/providers/amazon/aws/log/test_s3_task_handler.py
+++ b/tests/providers/amazon/aws/log/test_s3_task_handler.py
@@ -20,6 +20,7 @@
import unittest
from unittest import mock
+import pytest
from botocore.exceptions import ClientError
from airflow.models import DAG, TaskInstance
@@ -77,7 +78,7 @@ def tearDown(self):
pass
def test_hook(self):
- self.assertIsInstance(self.s3_task_handler.hook, S3Hook)
+ assert isinstance(self.s3_task_handler.hook, S3Hook)
@conf_vars({('logging', 'remote_log_conn_id'): 'aws_default'})
def test_hook_raises(self):
@@ -97,18 +98,18 @@ def test_hook_raises(self):
def test_log_exists(self):
self.conn.put_object(Bucket='bucket', Key=self.remote_log_key, Body=b'')
- self.assertTrue(self.s3_task_handler.s3_log_exists(self.remote_log_location))
+ assert self.s3_task_handler.s3_log_exists(self.remote_log_location)
def test_log_exists_none(self):
- self.assertFalse(self.s3_task_handler.s3_log_exists(self.remote_log_location))
+ assert not self.s3_task_handler.s3_log_exists(self.remote_log_location)
def test_log_exists_raises(self):
- self.assertFalse(self.s3_task_handler.s3_log_exists('s3://nonexistentbucket/foo'))
+ assert not self.s3_task_handler.s3_log_exists('s3://nonexistentbucket/foo')
def test_log_exists_no_hook(self):
with mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook") as mock_hook:
mock_hook.side_effect = Exception('Failed to connect')
- with self.assertRaises(Exception):
+ with pytest.raises(Exception):
self.s3_task_handler.s3_log_exists(self.remote_log_location)
def test_set_context_raw(self):
@@ -117,7 +118,7 @@ def test_set_context_raw(self):
with mock.patch('airflow.providers.amazon.aws.log.s3_task_handler.open', mock_open):
self.s3_task_handler.set_context(self.ti)
- self.assertFalse(self.s3_task_handler.upload_on_close)
+ assert not self.s3_task_handler.upload_on_close
mock_open.assert_not_called()
def test_set_context_not_raw(self):
@@ -125,26 +126,26 @@ def test_set_context_not_raw(self):
with mock.patch('airflow.providers.amazon.aws.log.s3_task_handler.open', mock_open):
self.s3_task_handler.set_context(self.ti)
- self.assertTrue(self.s3_task_handler.upload_on_close)
+ assert self.s3_task_handler.upload_on_close
mock_open.assert_called_once_with(os.path.abspath('local/log/location/1.log'), 'w')
mock_open().write.assert_not_called()
def test_read(self):
self.conn.put_object(Bucket='bucket', Key=self.remote_log_key, Body=b'Log line\n')
log, metadata = self.s3_task_handler.read(self.ti)
- self.assertEqual(
- log[0][0][-1],
- '*** Reading remote log from s3://bucket/remote/log/location/1.log.\nLog line\n\n',
+ assert (
+ log[0][0][-1]
+ == '*** Reading remote log from s3://bucket/remote/log/location/1.log.\nLog line\n\n'
)
- self.assertEqual(metadata, [{'end_of_log': True}])
+ assert metadata == [{'end_of_log': True}]
def test_read_when_s3_log_missing(self):
log, metadata = self.s3_task_handler.read(self.ti)
- self.assertEqual(1, len(log))
- self.assertEqual(len(log), len(metadata))
- self.assertIn('*** Log file does not exist:', log[0][0][-1])
- self.assertEqual({'end_of_log': True}, metadata[0])
+ assert 1 == len(log)
+ assert len(log) == len(metadata)
+ assert '*** Log file does not exist:' in log[0][0][-1]
+ assert {'end_of_log': True} == metadata[0]
def test_s3_read_when_log_missing(self):
handler = self.s3_task_handler
@@ -155,7 +156,7 @@ def test_s3_read_when_log_missing(self):
f'Could not read logs from {url} with error: An error occurred (404) when calling the '
f'HeadObject operation: Not Found'
)
- self.assertEqual(result, msg)
+ assert result == msg
mock_error.assert_called_once_with(msg, exc_info=True)
def test_read_raises_return_error(self):
@@ -167,7 +168,7 @@ def test_read_raises_return_error(self):
f'Could not read logs from {url} with error: An error occurred (NoSuchBucket) when '
f'calling the HeadObject operation: The specified bucket does not exist'
)
- self.assertEqual(result, msg)
+ assert result == msg
mock_error.assert_called_once_with(msg, exc_info=True)
def test_write(self):
@@ -182,7 +183,7 @@ def test_write(self):
.read()
)
- self.assertEqual(body, b'text')
+ assert body == b'text'
def test_write_existing(self):
self.conn.put_object(Bucket='bucket', Key=self.remote_log_key, Body=b'previous ')
@@ -194,19 +195,18 @@ def test_write_existing(self):
.read()
)
- self.assertEqual(body, b'previous \ntext')
+ assert body == b'previous \ntext'
def test_write_raises(self):
handler = self.s3_task_handler
url = 's3://nonexistentbucket/foo'
with mock.patch.object(handler.log, 'error') as mock_error:
handler.s3_write('text', url)
- self.assertEqual
mock_error.assert_called_once_with('Could not write logs to %s', url, exc_info=True)
def test_close(self):
self.s3_task_handler.set_context(self.ti)
- self.assertTrue(self.s3_task_handler.upload_on_close)
+ assert self.s3_task_handler.upload_on_close
self.s3_task_handler.close()
# Should not raise
@@ -215,8 +215,8 @@ def test_close(self):
def test_close_no_upload(self):
self.ti.raw = True
self.s3_task_handler.set_context(self.ti)
- self.assertFalse(self.s3_task_handler.upload_on_close)
+ assert not self.s3_task_handler.upload_on_close
self.s3_task_handler.close()
- with self.assertRaises(ClientError):
+ with pytest.raises(ClientError):
boto3.resource('s3').Object('bucket', self.remote_log_key).get() # pylint: disable=no-member
diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py
index 61e43d26aec1c..c1dd723447192 100644
--- a/tests/providers/amazon/aws/operators/test_athena.py
+++ b/tests/providers/amazon/aws/operators/test_athena.py
@@ -18,6 +18,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.models import DAG, TaskInstance
from airflow.providers.amazon.aws.hooks.athena import AWSAthenaHook
from airflow.providers.amazon.aws.operators.athena import AWSAthenaOperator
@@ -62,14 +64,14 @@ def setUp(self):
)
def test_init(self):
- self.assertEqual(self.athena.task_id, MOCK_DATA['task_id'])
- self.assertEqual(self.athena.query, MOCK_DATA['query'])
- self.assertEqual(self.athena.database, MOCK_DATA['database'])
- self.assertEqual(self.athena.aws_conn_id, 'aws_default')
- self.assertEqual(self.athena.client_request_token, MOCK_DATA['client_request_token'])
- self.assertEqual(self.athena.sleep_time, 0)
+ assert self.athena.task_id == MOCK_DATA['task_id']
+ assert self.athena.query == MOCK_DATA['query']
+ assert self.athena.database == MOCK_DATA['database']
+ assert self.athena.aws_conn_id == 'aws_default'
+ assert self.athena.client_request_token == MOCK_DATA['client_request_token']
+ assert self.athena.sleep_time == 0
- self.assertEqual(self.athena.hook.sleep_time, 0)
+ assert self.athena.hook.sleep_time == 0
@mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("SUCCESS",))
@mock.patch.object(AWSAthenaHook, 'run_query', return_value=ATHENA_QUERY_ID)
@@ -83,7 +85,7 @@ def test_hook_run_small_success_query(self, mock_conn, mock_run_query, mock_chec
MOCK_DATA['client_request_token'],
MOCK_DATA['workgroup'],
)
- self.assertEqual(mock_check_query_status.call_count, 1)
+ assert mock_check_query_status.call_count == 1
@mock.patch.object(
AWSAthenaHook,
@@ -105,7 +107,7 @@ def test_hook_run_big_success_query(self, mock_conn, mock_run_query, mock_check_
MOCK_DATA['client_request_token'],
MOCK_DATA['workgroup'],
)
- self.assertEqual(mock_check_query_status.call_count, 3)
+ assert mock_check_query_status.call_count == 3
@mock.patch.object(
AWSAthenaHook,
@@ -118,7 +120,7 @@ def test_hook_run_big_success_query(self, mock_conn, mock_run_query, mock_check_
@mock.patch.object(AWSAthenaHook, 'run_query', return_value=ATHENA_QUERY_ID)
@mock.patch.object(AWSAthenaHook, 'get_conn')
def test_hook_run_failed_query_with_none(self, mock_conn, mock_run_query, mock_check_query_status):
- with self.assertRaises(Exception):
+ with pytest.raises(Exception):
self.athena.execute(None)
mock_run_query.assert_called_once_with(
MOCK_DATA['query'],
@@ -127,7 +129,7 @@ def test_hook_run_failed_query_with_none(self, mock_conn, mock_run_query, mock_c
MOCK_DATA['client_request_token'],
MOCK_DATA['workgroup'],
)
- self.assertEqual(mock_check_query_status.call_count, 3)
+ assert mock_check_query_status.call_count == 3
@mock.patch.object(AWSAthenaHook, 'get_state_change_reason')
@mock.patch.object(
@@ -143,7 +145,7 @@ def test_hook_run_failed_query_with_none(self, mock_conn, mock_run_query, mock_c
def test_hook_run_failure_query(
self, mock_conn, mock_run_query, mock_check_query_status, mock_get_state_change_reason
):
- with self.assertRaises(Exception):
+ with pytest.raises(Exception):
self.athena.execute(None)
mock_run_query.assert_called_once_with(
MOCK_DATA['query'],
@@ -152,8 +154,8 @@ def test_hook_run_failure_query(
MOCK_DATA['client_request_token'],
MOCK_DATA['workgroup'],
)
- self.assertEqual(mock_check_query_status.call_count, 2)
- self.assertEqual(mock_get_state_change_reason.call_count, 1)
+ assert mock_check_query_status.call_count == 2
+ assert mock_get_state_change_reason.call_count == 1
@mock.patch.object(
AWSAthenaHook,
@@ -167,7 +169,7 @@ def test_hook_run_failure_query(
@mock.patch.object(AWSAthenaHook, 'run_query', return_value=ATHENA_QUERY_ID)
@mock.patch.object(AWSAthenaHook, 'get_conn')
def test_hook_run_cancelled_query(self, mock_conn, mock_run_query, mock_check_query_status):
- with self.assertRaises(Exception):
+ with pytest.raises(Exception):
self.athena.execute(None)
mock_run_query.assert_called_once_with(
MOCK_DATA['query'],
@@ -176,7 +178,7 @@ def test_hook_run_cancelled_query(self, mock_conn, mock_run_query, mock_check_qu
MOCK_DATA['client_request_token'],
MOCK_DATA['workgroup'],
)
- self.assertEqual(mock_check_query_status.call_count, 3)
+ assert mock_check_query_status.call_count == 3
@mock.patch.object(
AWSAthenaHook,
@@ -190,7 +192,7 @@ def test_hook_run_cancelled_query(self, mock_conn, mock_run_query, mock_check_qu
@mock.patch.object(AWSAthenaHook, 'run_query', return_value=ATHENA_QUERY_ID)
@mock.patch.object(AWSAthenaHook, 'get_conn')
def test_hook_run_failed_query_with_max_tries(self, mock_conn, mock_run_query, mock_check_query_status):
- with self.assertRaises(Exception):
+ with pytest.raises(Exception):
self.athena.execute(None)
mock_run_query.assert_called_once_with(
MOCK_DATA['query'],
@@ -199,7 +201,7 @@ def test_hook_run_failed_query_with_max_tries(self, mock_conn, mock_run_query, m
MOCK_DATA['client_request_token'],
MOCK_DATA['workgroup'],
)
- self.assertEqual(mock_check_query_status.call_count, 3)
+ assert mock_check_query_status.call_count == 3
@mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("SUCCESS",))
@mock.patch.object(AWSAthenaHook, 'run_query', return_value=ATHENA_QUERY_ID)
@@ -208,7 +210,7 @@ def test_xcom_push_and_pull(self, mock_conn, mock_run_query, mock_check_query_st
ti = TaskInstance(task=self.athena, execution_date=timezone.utcnow())
ti.run()
- self.assertEqual(ti.xcom_pull(task_ids='test_aws_athena_operator'), ATHENA_QUERY_ID)
+ assert ti.xcom_pull(task_ids='test_aws_athena_operator') == ATHENA_QUERY_ID
# pylint: enable=unused-argument
diff --git a/tests/providers/amazon/aws/operators/test_batch.py b/tests/providers/amazon/aws/operators/test_batch.py
index 76be1fe414136..2c5b9c83e715d 100644
--- a/tests/providers/amazon/aws/operators/test_batch.py
+++ b/tests/providers/amazon/aws/operators/test_batch.py
@@ -22,6 +22,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.batch_client import AwsBatchClientHook
from airflow.providers.amazon.aws.operators.batch import AwsBatchOperator
@@ -66,7 +68,7 @@ def setUp(self, get_client_type_mock):
tags={},
)
self.client_mock = self.get_client_type_mock.return_value
- self.assertEqual(self.batch.hook.client, self.client_mock) # setup client property
+ assert self.batch.hook.client == self.client_mock # setup client property
# don't pause in unit tests
self.mock_delay = mock.Mock(return_value=None)
@@ -75,35 +77,32 @@ def setUp(self, get_client_type_mock):
self.batch.exponential_delay = self.mock_exponential_delay
# Assign a job ID for most tests, so they don't depend on a job submission.
- self.assertIsNone(self.batch.job_id)
+ assert self.batch.job_id is None
self.batch.job_id = JOB_ID
def test_init(self):
- self.assertEqual(self.batch.job_id, JOB_ID)
- self.assertEqual(self.batch.job_name, JOB_NAME)
- self.assertEqual(self.batch.job_queue, "queue")
- self.assertEqual(self.batch.job_definition, "hello-world")
- self.assertEqual(self.batch.waiters, None)
- self.assertEqual(self.batch.hook.max_retries, self.MAX_RETRIES)
- self.assertEqual(self.batch.hook.status_retries, self.STATUS_RETRIES)
- self.assertEqual(self.batch.parameters, {})
- self.assertEqual(self.batch.overrides, {})
- self.assertEqual(self.batch.array_properties, {})
- self.assertEqual(self.batch.hook.region_name, "eu-west-1")
- self.assertEqual(self.batch.hook.aws_conn_id, "airflow_test")
- self.assertEqual(self.batch.hook.client, self.client_mock)
- self.assertEqual(self.batch.tags, {})
+ assert self.batch.job_id == JOB_ID
+ assert self.batch.job_name == JOB_NAME
+ assert self.batch.job_queue == "queue"
+ assert self.batch.job_definition == "hello-world"
+ assert self.batch.waiters is None
+ assert self.batch.hook.max_retries == self.MAX_RETRIES
+ assert self.batch.hook.status_retries == self.STATUS_RETRIES
+ assert self.batch.parameters == {}
+ assert self.batch.overrides == {}
+ assert self.batch.array_properties == {}
+ assert self.batch.hook.region_name == "eu-west-1"
+ assert self.batch.hook.aws_conn_id == "airflow_test"
+ assert self.batch.hook.client == self.client_mock
+ assert self.batch.tags == {}
self.get_client_type_mock.assert_called_once_with("batch", region_name="eu-west-1")
def test_template_fields_overrides(self):
- self.assertEqual(
- self.batch.template_fields,
- (
- "job_name",
- "overrides",
- "parameters",
- ),
+ assert self.batch.template_fields == (
+ "job_name",
+ "overrides",
+ "parameters",
)
@mock.patch.object(AwsBatchClientHook, "wait_for_job")
@@ -126,14 +125,14 @@ def test_execute_without_failures(self, check_mock, wait_mock):
tags={},
)
- self.assertEqual(self.batch.job_id, JOB_ID)
+ assert self.batch.job_id == JOB_ID
wait_mock.assert_called_once_with(JOB_ID)
check_mock.assert_called_once_with(JOB_ID)
def test_execute_with_failures(self):
self.client_mock.submit_job.return_value = ""
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.batch.execute(None)
self.client_mock.submit_job.assert_called_once_with(
diff --git a/tests/providers/amazon/aws/operators/test_datasync.py b/tests/providers/amazon/aws/operators/test_datasync.py
index 8885f9161f4a9..587196cba1c13 100644
--- a/tests/providers/amazon/aws/operators/test_datasync.py
+++ b/tests/providers/amazon/aws/operators/test_datasync.py
@@ -18,6 +18,7 @@
from unittest import mock
import boto3
+import pytest
from moto import mock_datasync
from airflow.exceptions import AirflowException
@@ -146,27 +147,21 @@ def set_up_operator(
def test_init(self, mock_get_conn):
self.set_up_operator()
# Airflow built-ins
- self.assertEqual(self.datasync.task_id, MOCK_DATA["create_task_id"])
+ assert self.datasync.task_id == MOCK_DATA["create_task_id"]
# Defaults
- self.assertEqual(self.datasync.aws_conn_id, "aws_default")
- self.assertFalse(self.datasync.allow_random_task_choice)
- self.assertFalse(self.datasync.task_execution_kwargs) # Empty dict
+ assert self.datasync.aws_conn_id == "aws_default"
+ assert not self.datasync.allow_random_task_choice
+ assert not self.datasync.task_execution_kwargs # Empty dict
# Assignments
- self.assertEqual(self.datasync.source_location_uri, MOCK_DATA["source_location_uri"])
- self.assertEqual(
- self.datasync.destination_location_uri,
- MOCK_DATA["destination_location_uri"],
+ assert self.datasync.source_location_uri == MOCK_DATA["source_location_uri"]
+ assert self.datasync.destination_location_uri == MOCK_DATA["destination_location_uri"]
+ assert self.datasync.create_task_kwargs == MOCK_DATA["create_task_kwargs"]
+ assert self.datasync.create_source_location_kwargs == MOCK_DATA["create_source_location_kwargs"]
+ assert (
+ self.datasync.create_destination_location_kwargs
+ == MOCK_DATA["create_destination_location_kwargs"]
)
- self.assertEqual(self.datasync.create_task_kwargs, MOCK_DATA["create_task_kwargs"])
- self.assertEqual(
- self.datasync.create_source_location_kwargs,
- MOCK_DATA["create_source_location_kwargs"],
- )
- self.assertEqual(
- self.datasync.create_destination_location_kwargs,
- MOCK_DATA["create_destination_location_kwargs"],
- )
- self.assertFalse(self.datasync.allow_random_location_choice)
+ assert not self.datasync.allow_random_location_choice
# ### Check mocks:
mock_get_conn.assert_not_called()
@@ -175,11 +170,11 @@ def test_init_fails(self, mock_get_conn):
mock_get_conn.return_value = self.client
# ### Begin tests:
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.set_up_operator(source_location_uri=None)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.set_up_operator(destination_location_uri=None)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.set_up_operator(source_location_uri=None, destination_location_uri=None)
# ### Check mocks:
mock_get_conn.assert_not_called()
@@ -197,24 +192,24 @@ def test_create_task(self, mock_get_conn):
# Check how many tasks and locations we have
tasks = self.client.list_tasks()
- self.assertEqual(len(tasks["Tasks"]), 0)
+ assert len(tasks["Tasks"]) == 0
locations = self.client.list_locations()
- self.assertEqual(len(locations["Locations"]), 2)
+ assert len(locations["Locations"]) == 2
# Execute the task
result = self.datasync.execute(None)
- self.assertIsNotNone(result)
+ assert result is not None
task_arn = result["TaskArn"]
# Assert 1 additional task and 0 additional locations
tasks = self.client.list_tasks()
- self.assertEqual(len(tasks["Tasks"]), 1)
+ assert len(tasks["Tasks"]) == 1
locations = self.client.list_locations()
- self.assertEqual(len(locations["Locations"]), 2)
+ assert len(locations["Locations"]) == 2
# Check task metadata
task = self.client.describe_task(TaskArn=task_arn)
- self.assertEqual(task["Options"], CREATE_TASK_KWARGS["Options"])
+ assert task["Options"] == CREATE_TASK_KWARGS["Options"]
# ### Check mocks:
mock_get_conn.assert_called()
@@ -235,19 +230,19 @@ def test_create_task_and_location(self, mock_get_conn):
# Check how many tasks and locations we have
tasks = self.client.list_tasks()
- self.assertEqual(len(tasks["Tasks"]), 0)
+ assert len(tasks["Tasks"]) == 0
locations = self.client.list_locations()
- self.assertEqual(len(locations["Locations"]), 0)
+ assert len(locations["Locations"]) == 0
# Execute the task
result = self.datasync.execute(None)
- self.assertIsNotNone(result)
+ assert result is not None
# Assert 1 additional task and 2 additional locations
tasks = self.client.list_tasks()
- self.assertEqual(len(tasks["Tasks"]), 1)
+ assert len(tasks["Tasks"]) == 1
locations = self.client.list_locations()
- self.assertEqual(len(locations["Locations"]), 2)
+ assert len(locations["Locations"]) == 2
# ### Check mocks:
mock_get_conn.assert_called()
@@ -264,7 +259,7 @@ def test_dont_create_task(self, mock_get_conn):
tasks = self.client.list_tasks()
tasks_after = len(tasks["Tasks"])
- self.assertEqual(tasks_before, tasks_after)
+ assert tasks_before == tasks_after
# ### Check mocks:
mock_get_conn.assert_called()
@@ -281,7 +276,7 @@ def test_create_task_many_locations(self, mock_get_conn):
self.client.create_location_smb(**MOCK_DATA["create_source_location_kwargs"])
self.set_up_operator(task_id='datasync_task1')
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.datasync.execute(None)
# Delete all tasks:
@@ -306,8 +301,8 @@ def test_execute_specific_task(self, mock_get_conn):
self.set_up_operator(task_arn=task_arn)
result = self.datasync.execute(None)
- self.assertEqual(result["TaskArn"], task_arn)
- self.assertEqual(self.datasync.task_arn, task_arn)
+ assert result["TaskArn"] == task_arn
+ assert self.datasync.task_arn == task_arn
# ### Check mocks:
mock_get_conn.assert_called()
@@ -320,7 +315,7 @@ def test_xcom_push(self, mock_get_conn):
ti = TaskInstance(task=self.datasync, execution_date=timezone.utcnow())
ti.run()
xcom_result = ti.xcom_pull(task_ids=self.datasync.task_id, key="return_value")
- self.assertIsNotNone(xcom_result)
+ assert xcom_result is not None
# ### Check mocks:
mock_get_conn.assert_called()
@@ -353,17 +348,14 @@ def set_up_operator(
def test_init(self, mock_get_conn):
self.set_up_operator()
# Airflow built-ins
- self.assertEqual(self.datasync.task_id, MOCK_DATA["get_task_id"])
+ assert self.datasync.task_id == MOCK_DATA["get_task_id"]
# Defaults
- self.assertEqual(self.datasync.aws_conn_id, "aws_default")
- self.assertFalse(self.datasync.allow_random_location_choice)
+ assert self.datasync.aws_conn_id == "aws_default"
+ assert not self.datasync.allow_random_location_choice
# Assignments
- self.assertEqual(self.datasync.source_location_uri, MOCK_DATA["source_location_uri"])
- self.assertEqual(
- self.datasync.destination_location_uri,
- MOCK_DATA["destination_location_uri"],
- )
- self.assertFalse(self.datasync.allow_random_task_choice)
+ assert self.datasync.source_location_uri == MOCK_DATA["source_location_uri"]
+ assert self.datasync.destination_location_uri == MOCK_DATA["destination_location_uri"]
+ assert not self.datasync.allow_random_task_choice
# ### Check mocks:
mock_get_conn.assert_not_called()
@@ -372,11 +364,11 @@ def test_init_fails(self, mock_get_conn):
mock_get_conn.return_value = self.client
# ### Begin tests:
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.set_up_operator(source_location_uri=None)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.set_up_operator(destination_location_uri=None)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.set_up_operator(source_location_uri=None, destination_location_uri=None)
# ### Check mocks:
mock_get_conn.assert_not_called()
@@ -392,15 +384,15 @@ def test_get_no_location(self, mock_get_conn):
self.client.delete_location(LocationArn=location["LocationArn"])
locations = self.client.list_locations()
- self.assertEqual(len(locations["Locations"]), 0)
+ assert len(locations["Locations"]) == 0
# Execute the task
result = self.datasync.execute(None)
- self.assertIsNotNone(result)
+ assert result is not None
locations = self.client.list_locations()
- self.assertIsNotNone(result)
- self.assertEqual(len(locations), 2)
+ assert result is not None
+ assert len(locations) == 2
# ### Check mocks:
mock_get_conn.assert_called()
@@ -415,11 +407,11 @@ def test_get_no_tasks2(self, mock_get_conn):
self.client.delete_task(TaskArn=task["TaskArn"])
tasks = self.client.list_tasks()
- self.assertEqual(len(tasks["Tasks"]), 0)
+ assert len(tasks["Tasks"]) == 0
# Execute the task
result = self.datasync.execute(None)
- self.assertIsNotNone(result)
+ assert result is not None
# ### Check mocks:
mock_get_conn.assert_called()
@@ -430,28 +422,28 @@ def test_get_one_task(self, mock_get_conn):
# Make sure we don't cheat
self.set_up_operator()
- self.assertEqual(self.datasync.task_arn, None)
+ assert self.datasync.task_arn is None
# Check how many tasks and locations we have
tasks = self.client.list_tasks()
- self.assertEqual(len(tasks["Tasks"]), 1)
+ assert len(tasks["Tasks"]) == 1
locations = self.client.list_locations()
- self.assertEqual(len(locations["Locations"]), 2)
+ assert len(locations["Locations"]) == 2
# Execute the task
result = self.datasync.execute(None)
- self.assertIsNotNone(result)
+ assert result is not None
task_arn = result["TaskArn"]
- self.assertIsNotNone(task_arn)
- self.assertTrue(task_arn)
- self.assertEqual(task_arn, self.task_arn)
+ assert task_arn is not None
+ assert task_arn
+ assert task_arn == self.task_arn
# Assert 0 additional task and 0 additional locations
tasks = self.client.list_tasks()
- self.assertEqual(len(tasks["Tasks"]), 1)
+ assert len(tasks["Tasks"]) == 1
locations = self.client.list_locations()
- self.assertEqual(len(locations["Locations"]), 2)
+ assert len(locations["Locations"]) == 2
# ### Check mocks:
mock_get_conn.assert_called()
@@ -469,19 +461,19 @@ def test_get_many_tasks(self, mock_get_conn):
# Check how many tasks and locations we have
tasks = self.client.list_tasks()
- self.assertEqual(len(tasks["Tasks"]), 2)
+ assert len(tasks["Tasks"]) == 2
locations = self.client.list_locations()
- self.assertEqual(len(locations["Locations"]), 2)
+ assert len(locations["Locations"]) == 2
# Execute the task
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.datasync.execute(None)
# Assert 0 additional task and 0 additional locations
tasks = self.client.list_tasks()
- self.assertEqual(len(tasks["Tasks"]), 2)
+ assert len(tasks["Tasks"]) == 2
locations = self.client.list_locations()
- self.assertEqual(len(locations["Locations"]), 2)
+ assert len(locations["Locations"]) == 2
self.set_up_operator(task_id='datasync_task2', task_arn=self.task_arn, allow_random_task_choice=True)
self.datasync.execute(None)
@@ -500,8 +492,8 @@ def test_execute_specific_task(self, mock_get_conn):
self.set_up_operator(task_arn=task_arn)
result = self.datasync.execute(None)
- self.assertEqual(result["TaskArn"], task_arn)
- self.assertEqual(self.datasync.task_arn, task_arn)
+ assert result["TaskArn"] == task_arn
+ assert self.datasync.task_arn == task_arn
# ### Check mocks:
mock_get_conn.assert_called()
@@ -514,7 +506,7 @@ def test_xcom_push(self, mock_get_conn):
ti = TaskInstance(task=self.datasync, execution_date=timezone.utcnow())
ti.run()
pushed_task_arn = ti.xcom_pull(task_ids=self.datasync.task_id, key="return_value")["TaskArn"]
- self.assertEqual(pushed_task_arn, self.task_arn)
+ assert pushed_task_arn == self.task_arn
# ### Check mocks:
mock_get_conn.assert_called()
@@ -545,12 +537,12 @@ def set_up_operator(
def test_init(self, mock_get_conn):
self.set_up_operator()
# Airflow built-ins
- self.assertEqual(self.datasync.task_id, MOCK_DATA["update_task_id"])
+ assert self.datasync.task_id == MOCK_DATA["update_task_id"]
# Defaults
- self.assertEqual(self.datasync.aws_conn_id, "aws_default")
+ assert self.datasync.aws_conn_id == "aws_default"
# Assignments
- self.assertEqual(self.datasync.task_arn, self.task_arn)
- self.assertEqual(self.datasync.update_task_kwargs, MOCK_DATA["update_task_kwargs"])
+ assert self.datasync.task_arn == self.task_arn
+ assert self.datasync.update_task_kwargs == MOCK_DATA["update_task_kwargs"]
# ### Check mocks:
mock_get_conn.assert_not_called()
@@ -559,7 +551,7 @@ def test_init_fails(self, mock_get_conn):
mock_get_conn.return_value = self.client
# ### Begin tests:
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.set_up_operator(task_arn=None)
# ### Check mocks:
mock_get_conn.assert_not_called()
@@ -573,18 +565,18 @@ def test_update_task(self, mock_get_conn):
# Check task before update
task = self.client.describe_task(TaskArn=self.task_arn)
- self.assertNotIn("Options", task)
+ assert "Options" not in task
# Execute the task
result = self.datasync.execute(None)
- self.assertIsNotNone(result)
- self.assertEqual(result["TaskArn"], self.task_arn)
+ assert result is not None
+ assert result["TaskArn"] == self.task_arn
- self.assertIsNotNone(self.datasync.task_arn)
+ assert self.datasync.task_arn is not None
# Check it was updated
task = self.client.describe_task(TaskArn=self.task_arn)
- self.assertEqual(task["Options"], UPDATE_TASK_KWARGS["Options"])
+ assert task["Options"] == UPDATE_TASK_KWARGS["Options"]
# ### Check mocks:
mock_get_conn.assert_called()
@@ -600,8 +592,8 @@ def test_execute_specific_task(self, mock_get_conn):
self.set_up_operator(task_arn=task_arn)
result = self.datasync.execute(None)
- self.assertEqual(result["TaskArn"], task_arn)
- self.assertEqual(self.datasync.task_arn, task_arn)
+ assert result["TaskArn"] == task_arn
+ assert self.datasync.task_arn == task_arn
# ### Check mocks:
mock_get_conn.assert_called()
@@ -614,7 +606,7 @@ def test_xcom_push(self, mock_get_conn):
ti = TaskInstance(task=self.datasync, execution_date=timezone.utcnow())
ti.run()
pushed_task_arn = ti.xcom_pull(task_ids=self.datasync.task_id, key="return_value")["TaskArn"]
- self.assertEqual(pushed_task_arn, self.task_arn)
+ assert pushed_task_arn == self.task_arn
# ### Check mocks:
mock_get_conn.assert_called()
@@ -636,12 +628,12 @@ def set_up_operator(self, task_id="test_aws_datasync_task_operator", task_arn="s
def test_init(self, mock_get_conn):
self.set_up_operator()
# Airflow built-ins
- self.assertEqual(self.datasync.task_id, MOCK_DATA["task_id"])
+ assert self.datasync.task_id == MOCK_DATA["task_id"]
# Defaults
- self.assertEqual(self.datasync.aws_conn_id, "aws_default")
- self.assertEqual(self.datasync.wait_interval_seconds, 0)
+ assert self.datasync.aws_conn_id == "aws_default"
+ assert self.datasync.wait_interval_seconds == 0
# Assignments
- self.assertEqual(self.datasync.task_arn, self.task_arn)
+ assert self.datasync.task_arn == self.task_arn
# ### Check mocks:
mock_get_conn.assert_not_called()
@@ -650,7 +642,7 @@ def test_init_fails(self, mock_get_conn):
mock_get_conn.return_value = self.client
# ### Begin tests:
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.set_up_operator(task_arn=None)
# ### Check mocks:
mock_get_conn.assert_not_called()
@@ -662,7 +654,7 @@ def test_execute_task(self, mock_get_conn):
# Configure the Operator with the specific task_arn
self.set_up_operator()
- self.assertEqual(self.datasync.task_arn, self.task_arn)
+ assert self.datasync.task_arn == self.task_arn
# Check how many tasks and locations we have
tasks = self.client.list_tasks()
@@ -672,19 +664,19 @@ def test_execute_task(self, mock_get_conn):
# Execute the task
result = self.datasync.execute(None)
- self.assertIsNotNone(result)
+ assert result is not None
task_execution_arn = result["TaskExecutionArn"]
- self.assertIsNotNone(task_execution_arn)
+ assert task_execution_arn is not None
# Assert 0 additional task and 0 additional locations
tasks = self.client.list_tasks()
- self.assertEqual(len(tasks["Tasks"]), len_tasks_before)
+ assert len(tasks["Tasks"]) == len_tasks_before
locations = self.client.list_locations()
- self.assertEqual(len(locations["Locations"]), len_locations_before)
+ assert len(locations["Locations"]) == len_locations_before
# Check with the DataSync client what happened
task_execution = self.client.describe_task_execution(TaskExecutionArn=task_execution_arn)
- self.assertEqual(task_execution["Status"], "SUCCESS")
+ assert task_execution["Status"] == "SUCCESS"
# Insist that this specific task was executed, not anything else
task_execution_arn = task_execution["TaskExecutionArn"]
@@ -692,7 +684,7 @@ def test_execute_task(self, mock_get_conn):
# arn:aws:datasync:us-east-1:111222333444:task/task-00000000000000003/execution/exec-00000000000000004
# format of task_arn:
# arn:aws:datasync:us-east-1:111222333444:task/task-00000000000000003
- self.assertEqual("/".join(task_execution_arn.split("/")[:2]), self.task_arn)
+ assert "/".join(task_execution_arn.split("/")[:2]) == self.task_arn
# ### Check mocks:
mock_get_conn.assert_called()
@@ -706,7 +698,7 @@ def test_failed_task(self, mock_wait, mock_get_conn):
self.set_up_operator()
# Execute the task
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.datasync.execute(None)
# ### Check mocks:
mock_get_conn.assert_called()
@@ -728,16 +720,16 @@ def kill_task(*args):
# Execute the task
result = self.datasync.execute(None)
- self.assertIsNotNone(result)
+ assert result is not None
task_execution_arn = result["TaskExecutionArn"]
- self.assertIsNotNone(task_execution_arn)
+ assert task_execution_arn is not None
# Verify the task was killed
task = self.client.describe_task(TaskArn=self.task_arn)
- self.assertEqual(task["Status"], "AVAILABLE")
+ assert task["Status"] == "AVAILABLE"
task_execution = self.client.describe_task_execution(TaskExecutionArn=task_execution_arn)
- self.assertEqual(task_execution["Status"], "ERROR")
+ assert task_execution["Status"] == "ERROR"
# ### Check mocks:
mock_get_conn.assert_called()
@@ -753,8 +745,8 @@ def test_execute_specific_task(self, mock_get_conn):
self.set_up_operator(task_arn=task_arn)
result = self.datasync.execute(None)
- self.assertEqual(result["TaskArn"], task_arn)
- self.assertEqual(self.datasync.task_arn, task_arn)
+ assert result["TaskArn"] == task_arn
+ assert self.datasync.task_arn == task_arn
# ### Check mocks:
mock_get_conn.assert_called()
@@ -767,7 +759,7 @@ def test_xcom_push(self, mock_get_conn):
ti = TaskInstance(task=self.datasync, execution_date=timezone.utcnow())
ti.run()
xcom_result = ti.xcom_pull(task_ids=self.datasync.task_id, key="return_value")
- self.assertIsNotNone(xcom_result)
+ assert xcom_result is not None
# ### Check mocks:
mock_get_conn.assert_called()
@@ -790,11 +782,11 @@ def set_up_operator(self, task_id="test_aws_datasync_delete_task_operator", task
def test_init(self, mock_get_conn):
self.set_up_operator()
# Airflow built-ins
- self.assertEqual(self.datasync.task_id, MOCK_DATA["delete_task_id"])
+ assert self.datasync.task_id == MOCK_DATA["delete_task_id"]
# Defaults
- self.assertEqual(self.datasync.aws_conn_id, "aws_default")
+ assert self.datasync.aws_conn_id == "aws_default"
# Assignments
- self.assertEqual(self.datasync.task_arn, self.task_arn)
+ assert self.datasync.task_arn == self.task_arn
# ### Check mocks:
mock_get_conn.assert_not_called()
@@ -803,7 +795,7 @@ def test_init_fails(self, mock_get_conn):
mock_get_conn.return_value = self.client
# ### Begin tests:
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.set_up_operator(task_arn=None)
# ### Check mocks:
mock_get_conn.assert_not_called()
@@ -817,20 +809,20 @@ def test_delete_task(self, mock_get_conn):
# Check how many tasks and locations we have
tasks = self.client.list_tasks()
- self.assertEqual(len(tasks["Tasks"]), 1)
+ assert len(tasks["Tasks"]) == 1
locations = self.client.list_locations()
- self.assertEqual(len(locations["Locations"]), 2)
+ assert len(locations["Locations"]) == 2
# Execute the task
result = self.datasync.execute(None)
- self.assertIsNotNone(result)
- self.assertEqual(result["TaskArn"], self.task_arn)
+ assert result is not None
+ assert result["TaskArn"] == self.task_arn
# Assert -1 additional task and 0 additional locations
tasks = self.client.list_tasks()
- self.assertEqual(len(tasks["Tasks"]), 0)
+ assert len(tasks["Tasks"]) == 0
locations = self.client.list_locations()
- self.assertEqual(len(locations["Locations"]), 2)
+ assert len(locations["Locations"]) == 2
# ### Check mocks:
mock_get_conn.assert_called()
@@ -846,8 +838,8 @@ def test_execute_specific_task(self, mock_get_conn):
self.set_up_operator(task_arn=task_arn)
result = self.datasync.execute(None)
- self.assertEqual(result["TaskArn"], task_arn)
- self.assertEqual(self.datasync.task_arn, task_arn)
+ assert result["TaskArn"] == task_arn
+ assert self.datasync.task_arn == task_arn
# ### Check mocks:
mock_get_conn.assert_called()
@@ -860,6 +852,6 @@ def test_xcom_push(self, mock_get_conn):
ti = TaskInstance(task=self.datasync, execution_date=timezone.utcnow())
ti.run()
pushed_task_arn = ti.xcom_pull(task_ids=self.datasync.task_id, key="return_value")["TaskArn"]
- self.assertEqual(pushed_task_arn, self.task_arn)
+ assert pushed_task_arn == self.task_arn
# ### Check mocks:
mock_get_conn.assert_called()
diff --git a/tests/providers/amazon/aws/operators/test_ec2_start_instance.py b/tests/providers/amazon/aws/operators/test_ec2_start_instance.py
index bd598d8970b72..994cf90769e46 100644
--- a/tests/providers/amazon/aws/operators/test_ec2_start_instance.py
+++ b/tests/providers/amazon/aws/operators/test_ec2_start_instance.py
@@ -34,11 +34,11 @@ def test_init(self):
region_name="region-test",
check_interval=3,
)
- self.assertEqual(ec2_operator.task_id, "task_test")
- self.assertEqual(ec2_operator.instance_id, "i-123abc")
- self.assertEqual(ec2_operator.aws_conn_id, "aws_conn_test")
- self.assertEqual(ec2_operator.region_name, "region-test")
- self.assertEqual(ec2_operator.check_interval, 3)
+ assert ec2_operator.task_id == "task_test"
+ assert ec2_operator.instance_id == "i-123abc"
+ assert ec2_operator.aws_conn_id == "aws_conn_test"
+ assert ec2_operator.region_name == "region-test"
+ assert ec2_operator.check_interval == 3
@mock_ec2
def test_start_instance(self):
@@ -57,4 +57,4 @@ def test_start_instance(self):
)
start_test.execute(None)
# assert instance state is running
- self.assertEqual(ec2_hook.get_instance_state(instance_id=instance_id), "running")
+ assert ec2_hook.get_instance_state(instance_id=instance_id) == "running"
diff --git a/tests/providers/amazon/aws/operators/test_ec2_stop_instance.py b/tests/providers/amazon/aws/operators/test_ec2_stop_instance.py
index e8d4d283961ac..6bc591b1eaab6 100644
--- a/tests/providers/amazon/aws/operators/test_ec2_stop_instance.py
+++ b/tests/providers/amazon/aws/operators/test_ec2_stop_instance.py
@@ -34,11 +34,11 @@ def test_init(self):
region_name="region-test",
check_interval=3,
)
- self.assertEqual(ec2_operator.task_id, "task_test")
- self.assertEqual(ec2_operator.instance_id, "i-123abc")
- self.assertEqual(ec2_operator.aws_conn_id, "aws_conn_test")
- self.assertEqual(ec2_operator.region_name, "region-test")
- self.assertEqual(ec2_operator.check_interval, 3)
+ assert ec2_operator.task_id == "task_test"
+ assert ec2_operator.instance_id == "i-123abc"
+ assert ec2_operator.aws_conn_id == "aws_conn_test"
+ assert ec2_operator.region_name == "region-test"
+ assert ec2_operator.check_interval == 3
@mock_ec2
def test_stop_instance(self):
@@ -57,4 +57,4 @@ def test_stop_instance(self):
)
stop_test.execute(None)
# assert instance state is running
- self.assertEqual(ec2_hook.get_instance_state(instance_id=instance_id), "stopped")
+ assert ec2_hook.get_instance_state(instance_id=instance_id) == "stopped"
diff --git a/tests/providers/amazon/aws/operators/test_ecs.py b/tests/providers/amazon/aws/operators/test_ecs.py
index 850c4cf9b1a00..7465f0c089772 100644
--- a/tests/providers/amazon/aws/operators/test_ecs.py
+++ b/tests/providers/amazon/aws/operators/test_ecs.py
@@ -22,6 +22,7 @@
from copy import deepcopy
from unittest import mock
+import pytest
from parameterized import parameterized
from airflow.exceptions import AirflowException
@@ -80,17 +81,17 @@ def setUp(self):
self.set_up_operator() # pylint: disable=no-value-for-parameter
def test_init(self):
- self.assertEqual(self.ecs.region_name, 'eu-west-1')
- self.assertEqual(self.ecs.task_definition, 't')
- self.assertEqual(self.ecs.aws_conn_id, None)
- self.assertEqual(self.ecs.cluster, 'c')
- self.assertEqual(self.ecs.overrides, {})
+ assert self.ecs.region_name == 'eu-west-1'
+ assert self.ecs.task_definition == 't'
+ assert self.ecs.aws_conn_id is None
+ assert self.ecs.cluster == 'c'
+ assert self.ecs.overrides == {}
self.ecs.get_hook()
- self.assertEqual(self.ecs.hook, self.aws_hook_mock.return_value)
+ assert self.ecs.hook == self.aws_hook_mock.return_value
self.aws_hook_mock.assert_called_once()
def test_template_fields_overrides(self):
- self.assertEqual(self.ecs.template_fields, ('overrides',))
+ assert self.ecs.template_fields == ('overrides',)
@parameterized.expand(
[
@@ -136,9 +137,7 @@ def test_execute_without_failures(self, launch_type, tags, check_mock, wait_mock
wait_mock.assert_called_once_with()
check_mock.assert_called_once_with()
- self.assertEqual(
- self.ecs.arn, 'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55'
- )
+ assert self.ecs.arn == 'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55'
def test_execute_with_failures(self):
client_mock = self.aws_hook_mock.return_value.get_conn.return_value
@@ -146,7 +145,7 @@ def test_execute_with_failures(self):
resp_failures['failures'].append('dummy error')
client_mock.run_task.return_value = resp_failures
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.ecs.execute(None)
self.aws_hook_mock.return_value.get_conn.assert_called_once()
@@ -176,7 +175,7 @@ def test_wait_end_tasks(self):
self.ecs._wait_for_task_ended()
client_mock.get_waiter.assert_called_once_with('tasks_stopped')
client_mock.get_waiter.return_value.wait.assert_called_once_with(cluster='c', tasks=['arn'])
- self.assertEqual(sys.maxsize, client_mock.get_waiter.return_value.config.max_attempts)
+ assert sys.maxsize == client_mock.get_waiter.return_value.config.max_attempts
def test_check_success_tasks_raises(self):
client_mock = mock.Mock()
@@ -186,14 +185,14 @@ def test_check_success_tasks_raises(self):
client_mock.describe_tasks.return_value = {
'tasks': [{'containers': [{'name': 'foo', 'lastStatus': 'STOPPED', 'exitCode': 1}]}]
}
- with self.assertRaises(Exception) as e:
+ with pytest.raises(Exception) as ctx:
self.ecs._check_success_task()
# Ordering of str(dict) is not guaranteed.
- self.assertIn("This task is not in success state ", str(e.exception))
- self.assertIn("'name': 'foo'", str(e.exception))
- self.assertIn("'lastStatus': 'STOPPED'", str(e.exception))
- self.assertIn("'exitCode': 1", str(e.exception))
+ assert "This task is not in success state " in str(ctx.value)
+ assert "'name': 'foo'" in str(ctx.value)
+ assert "'lastStatus': 'STOPPED'" in str(ctx.value)
+ assert "'exitCode': 1" in str(ctx.value)
client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])
def test_check_success_tasks_raises_pending(self):
@@ -203,12 +202,12 @@ def test_check_success_tasks_raises_pending(self):
client_mock.describe_tasks.return_value = {
'tasks': [{'containers': [{'name': 'container-name', 'lastStatus': 'PENDING'}]}]
}
- with self.assertRaises(Exception) as e:
+ with pytest.raises(Exception) as ctx:
self.ecs._check_success_task()
# Ordering of str(dict) is not guaranteed.
- self.assertIn("This task is still pending ", str(e.exception))
- self.assertIn("'name': 'container-name'", str(e.exception))
- self.assertIn("'lastStatus': 'PENDING'", str(e.exception))
+ assert "This task is still pending " in str(ctx.value)
+ assert "'name': 'container-name'" in str(ctx.value)
+ assert "'lastStatus': 'PENDING'" in str(ctx.value)
client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])
def test_check_success_tasks_raises_multiple(self):
@@ -252,12 +251,12 @@ def test_host_terminated_raises(self):
]
}
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
self.ecs._check_success_task()
- self.assertIn("The task was stopped because the host instance terminated:", str(e.exception))
- self.assertIn("Host EC2 (", str(e.exception))
- self.assertIn(") terminated", str(e.exception))
+ assert "The task was stopped because the host instance terminated:" in str(ctx.value)
+ assert "Host EC2 (" in str(ctx.value)
+ assert ") terminated" in str(ctx.value)
client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])
def test_check_success_task_not_raises(self):
@@ -311,21 +310,19 @@ def test_reattach_successful(self, launch_type, tags, start_mock, check_mock, wa
start_mock.assert_not_called()
wait_mock.assert_called_once_with()
check_mock.assert_called_once_with()
- self.assertEqual(
- self.ecs.arn, 'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55'
- )
+ assert self.ecs.arn == 'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55'
@mock.patch.object(ECSOperator, '_last_log_message', return_value="Log output")
def test_execute_xcom_with_log(self, mock_cloudwatch_log_message):
self.ecs.do_xcom_push = True
- self.assertEqual(self.ecs.execute(None), mock_cloudwatch_log_message.return_value)
+ assert self.ecs.execute(None) == mock_cloudwatch_log_message.return_value
@mock.patch.object(ECSOperator, '_last_log_message', return_value=None)
def test_execute_xcom_with_no_log(self, mock_cloudwatch_log_message):
self.ecs.do_xcom_push = True
- self.assertEqual(self.ecs.execute(None), mock_cloudwatch_log_message.return_value)
+ assert self.ecs.execute(None) == mock_cloudwatch_log_message.return_value
@mock.patch.object(ECSOperator, '_last_log_message', return_value="Log output")
def test_execute_xcom_disabled(self, mock_cloudwatch_log_message):
self.ecs.do_xcom_push = False
- self.assertEqual(self.ecs.execute(None), None)
+ assert self.ecs.execute(None) is None
diff --git a/tests/providers/amazon/aws/operators/test_emr_add_steps.py b/tests/providers/amazon/aws/operators/test_emr_add_steps.py
index fe68f059765dd..77aef3e314159 100644
--- a/tests/providers/amazon/aws/operators/test_emr_add_steps.py
+++ b/tests/providers/amazon/aws/operators/test_emr_add_steps.py
@@ -21,6 +21,7 @@
from datetime import timedelta
from unittest.mock import MagicMock, patch
+import pytest
from jinja2 import StrictUndefined
from airflow.exceptions import AirflowException
@@ -75,8 +76,8 @@ def setUp(self):
)
def test_init(self):
- self.assertEqual(self.operator.job_flow_id, 'j-8989898989')
- self.assertEqual(self.operator.aws_conn_id, 'aws_default')
+ assert self.operator.job_flow_id == 'j-8989898989'
+ assert self.operator.aws_conn_id == 'aws_default'
def test_render_template(self):
ti = TaskInstance(self.operator, DEFAULT_DATE)
@@ -97,7 +98,7 @@ def test_render_template(self):
}
]
- self.assertListEqual(self.operator.steps, expected_args)
+ assert self.operator.steps == expected_args
def test_render_template_2(self):
dag = DAG(dag_id='test_xcom', default_args=self.args)
@@ -178,7 +179,7 @@ def test_execute_returns_step_id(self):
self.emr_client_mock.add_job_flow_steps.return_value = ADD_STEPS_SUCCESS_RETURN
with patch('boto3.session.Session', self.boto3_session_mock):
- self.assertEqual(self.operator.execute(self.mock_context), ['s-2LH3R5GW3A53T'])
+ assert self.operator.execute(self.mock_context) == ['s-2LH3R5GW3A53T']
def test_init_with_cluster_name(self):
expected_job_flow_id = 'j-1231231234'
@@ -221,6 +222,6 @@ def test_init_with_nonexistent_cluster_name(self):
dag=DAG('test_dag_id', default_args=self.args),
)
- with self.assertRaises(AirflowException) as error:
+ with pytest.raises(AirflowException) as ctx:
operator.execute(self.mock_context)
- self.assertEqual(str(error.exception), f'No cluster found for name: {cluster_name}')
+ assert str(ctx.value) == f'No cluster found for name: {cluster_name}'
diff --git a/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py b/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py
index eb6883d220c5e..7e8ad2c72b211 100644
--- a/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py
+++ b/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py
@@ -75,9 +75,9 @@ def setUp(self):
)
def test_init(self):
- self.assertEqual(self.operator.aws_conn_id, 'aws_default')
- self.assertEqual(self.operator.emr_conn_id, 'emr_default')
- self.assertEqual(self.operator.region_name, 'ap-southeast-2')
+ assert self.operator.aws_conn_id == 'aws_default'
+ assert self.operator.emr_conn_id == 'emr_default'
+ assert self.operator.region_name == 'ap-southeast-2'
def test_render_template(self):
self.operator.job_flow_overrides = self._config
@@ -103,7 +103,7 @@ def test_render_template(self):
],
}
- self.assertDictEqual(self.operator.job_flow_overrides, expected_args)
+ assert self.operator.job_flow_overrides == expected_args
def test_render_template_from_file(self):
self.operator.job_flow_overrides = 'job.j2.json'
@@ -139,7 +139,7 @@ def test_render_template_from_file(self):
],
}
- self.assertDictEqual(self.operator.job_flow_overrides, expected_args)
+ assert self.operator.job_flow_overrides == expected_args
def test_execute_returns_job_id(self):
self.emr_client_mock.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN
@@ -150,4 +150,4 @@ def test_execute_returns_job_id(self):
boto3_session_mock = MagicMock(return_value=emr_session_mock)
with patch('boto3.session.Session', boto3_session_mock):
- self.assertEqual(self.operator.execute(None), 'j-8989898989')
+ assert self.operator.execute(None) == 'j-8989898989'
diff --git a/tests/providers/amazon/aws/operators/test_emr_modify_cluster.py b/tests/providers/amazon/aws/operators/test_emr_modify_cluster.py
index ef284ef756963..86feed28a29e9 100644
--- a/tests/providers/amazon/aws/operators/test_emr_modify_cluster.py
+++ b/tests/providers/amazon/aws/operators/test_emr_modify_cluster.py
@@ -19,6 +19,8 @@
import unittest
from unittest.mock import MagicMock, patch
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.models.dag import DAG
from airflow.providers.amazon.aws.operators.emr_modify_cluster import EmrModifyClusterOperator
@@ -54,18 +56,19 @@ def setUp(self):
)
def test_init(self):
- self.assertEqual(self.operator.cluster_id, 'j-8989898989')
- self.assertEqual(self.operator.step_concurrency_level, 1)
- self.assertEqual(self.operator.aws_conn_id, 'aws_default')
+ assert self.operator.cluster_id == 'j-8989898989'
+ assert self.operator.step_concurrency_level == 1
+ assert self.operator.aws_conn_id == 'aws_default'
def test_execute_returns_step_concurrency(self):
self.emr_client_mock.modify_cluster.return_value = MODIFY_CLUSTER_SUCCESS_RETURN
with patch('boto3.session.Session', self.boto3_session_mock):
- self.assertEqual(self.operator.execute(self.mock_context), 1)
+ assert self.operator.execute(self.mock_context) == 1
def test_execute_returns_error(self):
self.emr_client_mock.modify_cluster.return_value = MODIFY_CLUSTER_ERROR_RETURN
with patch('boto3.session.Session', self.boto3_session_mock):
- self.assertRaises(AirflowException, self.operator.execute, self.mock_context)
+ with pytest.raises(AirflowException):
+ self.operator.execute(self.mock_context)
diff --git a/tests/providers/amazon/aws/operators/test_glue.py b/tests/providers/amazon/aws/operators/test_glue.py
index 144bebed29f83..5896e841cff3d 100644
--- a/tests/providers/amazon/aws/operators/test_glue.py
+++ b/tests/providers/amazon/aws/operators/test_glue.py
@@ -53,4 +53,4 @@ def test_execute_without_failure(
self.glue.execute(None)
mock_initialize_job.assert_called_once_with({})
- self.assertEqual(self.glue.job_name, 'my_test_job')
+ assert self.glue.job_name == 'my_test_job'
diff --git a/tests/providers/amazon/aws/operators/test_s3_copy_object.py b/tests/providers/amazon/aws/operators/test_s3_copy_object.py
index c9f6af31e335a..ac810afa5d095 100644
--- a/tests/providers/amazon/aws/operators/test_s3_copy_object.py
+++ b/tests/providers/amazon/aws/operators/test_s3_copy_object.py
@@ -40,7 +40,7 @@ def test_s3_copy_object_arg_combination_1(self):
conn.upload_fileobj(Bucket=self.source_bucket, Key=self.source_key, Fileobj=io.BytesIO(b"input"))
# there should be nothing found before S3CopyObjectOperator is executed
- self.assertFalse('Contents' in conn.list_objects(Bucket=self.dest_bucket, Prefix=self.dest_key))
+ assert 'Contents' not in conn.list_objects(Bucket=self.dest_bucket, Prefix=self.dest_key)
op = S3CopyObjectOperator(
task_id="test_task_s3_copy_object",
@@ -53,9 +53,9 @@ def test_s3_copy_object_arg_combination_1(self):
objects_in_dest_bucket = conn.list_objects(Bucket=self.dest_bucket, Prefix=self.dest_key)
# there should be object found, and there should only be one object found
- self.assertEqual(len(objects_in_dest_bucket['Contents']), 1)
+ assert len(objects_in_dest_bucket['Contents']) == 1
# the object found should be consistent with dest_key specified earlier
- self.assertEqual(objects_in_dest_bucket['Contents'][0]['Key'], self.dest_key)
+ assert objects_in_dest_bucket['Contents'][0]['Key'] == self.dest_key
@mock_s3
def test_s3_copy_object_arg_combination_2(self):
@@ -65,7 +65,7 @@ def test_s3_copy_object_arg_combination_2(self):
conn.upload_fileobj(Bucket=self.source_bucket, Key=self.source_key, Fileobj=io.BytesIO(b"input"))
# there should be nothing found before S3CopyObjectOperator is executed
- self.assertFalse('Contents' in conn.list_objects(Bucket=self.dest_bucket, Prefix=self.dest_key))
+ assert 'Contents' not in conn.list_objects(Bucket=self.dest_bucket, Prefix=self.dest_key)
source_key_s3_url = f"s3://{self.source_bucket}/{self.source_key}"
dest_key_s3_url = f"s3://{self.dest_bucket}/{self.dest_key}"
@@ -78,6 +78,6 @@ def test_s3_copy_object_arg_combination_2(self):
objects_in_dest_bucket = conn.list_objects(Bucket=self.dest_bucket, Prefix=self.dest_key)
# there should be object found, and there should only be one object found
- self.assertEqual(len(objects_in_dest_bucket['Contents']), 1)
+ assert len(objects_in_dest_bucket['Contents']) == 1
# the object found should be consistent with dest_key specified earlier
- self.assertEqual(objects_in_dest_bucket['Contents'][0]['Key'], self.dest_key)
+ assert objects_in_dest_bucket['Contents'][0]['Key'] == self.dest_key
diff --git a/tests/providers/amazon/aws/operators/test_s3_delete_objects.py b/tests/providers/amazon/aws/operators/test_s3_delete_objects.py
index 5d7821c4d043f..d134da4e9c596 100644
--- a/tests/providers/amazon/aws/operators/test_s3_delete_objects.py
+++ b/tests/providers/amazon/aws/operators/test_s3_delete_objects.py
@@ -37,14 +37,14 @@ def test_s3_delete_single_object(self):
# The object should be detected before the DELETE action is taken
objects_in_dest_bucket = conn.list_objects(Bucket=bucket, Prefix=key)
- self.assertEqual(len(objects_in_dest_bucket['Contents']), 1)
- self.assertEqual(objects_in_dest_bucket['Contents'][0]['Key'], key)
+ assert len(objects_in_dest_bucket['Contents']) == 1
+ assert objects_in_dest_bucket['Contents'][0]['Key'] == key
op = S3DeleteObjectsOperator(task_id="test_task_s3_delete_single_object", bucket=bucket, keys=key)
op.execute(None)
# There should be no object found in the bucket created earlier
- self.assertFalse('Contents' in conn.list_objects(Bucket=bucket, Prefix=key))
+ assert 'Contents' not in conn.list_objects(Bucket=bucket, Prefix=key)
@mock_s3
def test_s3_delete_multiple_objects(self):
@@ -60,14 +60,14 @@ def test_s3_delete_multiple_objects(self):
# The objects should be detected before the DELETE action is taken
objects_in_dest_bucket = conn.list_objects(Bucket=bucket, Prefix=key_pattern)
- self.assertEqual(len(objects_in_dest_bucket['Contents']), n_keys)
- self.assertEqual(sorted([x['Key'] for x in objects_in_dest_bucket['Contents']]), sorted(keys))
+ assert len(objects_in_dest_bucket['Contents']) == n_keys
+ assert sorted([x['Key'] for x in objects_in_dest_bucket['Contents']]) == sorted(keys)
op = S3DeleteObjectsOperator(task_id="test_task_s3_delete_multiple_objects", bucket=bucket, keys=keys)
op.execute(None)
# There should be no object found in the bucket created earlier
- self.assertFalse('Contents' in conn.list_objects(Bucket=bucket, Prefix=key_pattern))
+ assert 'Contents' not in conn.list_objects(Bucket=bucket, Prefix=key_pattern)
@mock_s3
def test_s3_delete_prefix(self):
@@ -83,11 +83,11 @@ def test_s3_delete_prefix(self):
# The objects should be detected before the DELETE action is taken
objects_in_dest_bucket = conn.list_objects(Bucket=bucket, Prefix=key_pattern)
- self.assertEqual(len(objects_in_dest_bucket['Contents']), n_keys)
- self.assertEqual(sorted([x['Key'] for x in objects_in_dest_bucket['Contents']]), sorted(keys))
+ assert len(objects_in_dest_bucket['Contents']) == n_keys
+ assert sorted([x['Key'] for x in objects_in_dest_bucket['Contents']]) == sorted(keys)
op = S3DeleteObjectsOperator(task_id="test_task_s3_delete_prefix", bucket=bucket, prefix=key_pattern)
op.execute(None)
# There should be no object found in the bucket created earlier
- self.assertFalse('Contents' in conn.list_objects(Bucket=bucket, Prefix=key_pattern))
+ assert 'Contents' not in conn.list_objects(Bucket=bucket, Prefix=key_pattern)
diff --git a/tests/providers/amazon/aws/operators/test_s3_file_transform.py b/tests/providers/amazon/aws/operators/test_s3_file_transform.py
index 8394a5078570c..a5df0e3e9c0a0 100644
--- a/tests/providers/amazon/aws/operators/test_s3_file_transform.py
+++ b/tests/providers/amazon/aws/operators/test_s3_file_transform.py
@@ -27,6 +27,7 @@
from unittest import mock
import boto3
+import pytest
from moto import mock_s3
from airflow.exceptions import AirflowException
@@ -87,10 +88,10 @@ def test_execute_with_failing_transform_script(self, mock_popen):
task_id="task_id",
)
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
op.execute(None)
- self.assertEqual('Transform script failed: 42', str(e.exception))
+ assert 'Transform script failed: 42' == str(ctx.value)
@mock.patch('subprocess.Popen')
@mock_s3
@@ -109,7 +110,7 @@ def test_execute_with_transform_script_args(self, mock_popen):
)
op.execute(None)
- self.assertEqual(script_args, mock_popen.call_args[0][0][3:])
+ assert script_args == mock_popen.call_args[0][0][3:]
@mock.patch('airflow.providers.amazon.aws.hooks.s3.S3Hook.select_key', return_value="input")
@mock_s3
@@ -130,7 +131,7 @@ def test_execute_with_select_expression(self, mock_select_key):
conn = boto3.client('s3')
result = conn.get_object(Bucket=self.bucket, Key=self.output_key)
- self.assertEqual(self.content, result['Body'].read())
+ assert self.content == result['Body'].read()
@staticmethod
def mock_process(mock_popen, return_code=0, process_output=None):
diff --git a/tests/providers/amazon/aws/operators/test_s3_list.py b/tests/providers/amazon/aws/operators/test_s3_list.py
index b51c8fac89b65..249b97140651b 100644
--- a/tests/providers/amazon/aws/operators/test_s3_list.py
+++ b/tests/providers/amazon/aws/operators/test_s3_list.py
@@ -41,4 +41,4 @@ def test_execute(self, mock_hook):
mock_hook.return_value.list_keys.assert_called_once_with(
bucket_name=BUCKET, prefix=PREFIX, delimiter=DELIMITER
)
- self.assertEqual(sorted(files), sorted(MOCK_FILES))
+ assert sorted(files) == sorted(MOCK_FILES)
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_base.py b/tests/providers/amazon/aws/operators/test_sagemaker_base.py
index 57f3e36a2a4e3..6b128da7e5fc2 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_base.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_base.py
@@ -33,4 +33,4 @@ def setUp(self):
def test_parse_integer(self):
self.sagemaker.integer_fields = [['key1'], ['key2', 'key3'], ['key2', 'key4'], ['key5', 'key6']]
self.sagemaker.parse_config_integers()
- self.assertEqual(self.sagemaker.config, parsed_config)
+ assert self.sagemaker.config == parsed_config
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
index d08c92453842c..9c68ad45715f9 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
@@ -19,6 +19,7 @@
import unittest
from unittest import mock
+import pytest
from botocore.exceptions import ClientError
from airflow.exceptions import AirflowException
@@ -77,7 +78,7 @@ def setUp(self):
def test_parse_config_integers(self):
self.sagemaker.parse_config_integers()
for variant in self.sagemaker.config['EndpointConfig']['ProductionVariants']:
- self.assertEqual(variant['InitialInstanceCount'], int(variant['InitialInstanceCount']))
+ assert variant['InitialInstanceCount'] == int(variant['InitialInstanceCount'])
@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'create_model')
@@ -98,7 +99,8 @@ def test_execute(self, mock_endpoint, mock_endpoint_config, mock_model, mock_cli
@mock.patch.object(SageMakerHook, 'create_endpoint')
def test_execute_with_failure(self, mock_endpoint, mock_endpoint_config, mock_model, mock_client):
mock_endpoint.return_value = {'EndpointArn': 'testarn', 'ResponseMetadata': {'HTTPStatusCode': 404}}
- self.assertRaises(AirflowException, self.sagemaker.execute, None)
+ with pytest.raises(AirflowException):
+ self.sagemaker.execute(None)
@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'create_model')
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py
index e725c0228a706..b8bf18c88039a 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py
@@ -19,6 +19,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators.sagemaker_endpoint_config import SageMakerEndpointConfigOperator
@@ -50,7 +52,7 @@ def setUp(self):
def test_parse_config_integers(self):
self.sagemaker.parse_config_integers()
for variant in self.sagemaker.config['ProductionVariants']:
- self.assertEqual(variant['InitialInstanceCount'], int(variant['InitialInstanceCount']))
+ assert variant['InitialInstanceCount'] == int(variant['InitialInstanceCount'])
@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'create_endpoint_config')
@@ -69,4 +71,5 @@ def test_execute_with_failure(self, mock_model, mock_client):
'EndpointConfigArn': 'testarn',
'ResponseMetadata': {'HTTPStatusCode': 200},
}
- self.assertRaises(AirflowException, self.sagemaker.execute, None)
+ with pytest.raises(AirflowException):
+ self.sagemaker.execute(None)
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_model.py b/tests/providers/amazon/aws/operators/test_sagemaker_model.py
index 6676f1009c395..17ba3d91fb8e2 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_model.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_model.py
@@ -19,6 +19,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators.sagemaker_model import SageMakerModelOperator
@@ -59,4 +61,5 @@ def test_execute(self, mock_model, mock_client):
@mock.patch.object(SageMakerHook, 'create_model')
def test_execute_with_failure(self, mock_model, mock_client):
mock_model.return_value = {'ModelArn': 'testarn', 'ResponseMetadata': {'HTTPStatusCode': 404}}
- self.assertRaises(AirflowException, self.sagemaker.execute, None)
+ with pytest.raises(AirflowException):
+ self.sagemaker.execute(None)
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
index 5e55b39006167..6ceb286b25289 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
@@ -18,6 +18,7 @@
import unittest
from unittest import mock
+import pytest
from parameterized import parameterized
from airflow.exceptions import AirflowException
@@ -138,7 +139,8 @@ def test_execute_with_failure(self, mock_processing, mock_client):
sagemaker = SageMakerProcessingOperator(
**self.processing_config_kwargs, config=create_processing_params
)
- self.assertRaises(AirflowException, sagemaker.execute, None)
+ with pytest.raises(AirflowException):
+ sagemaker.execute(None)
@mock.patch.object(SageMakerHook, "get_conn")
@mock.patch.object(SageMakerHook, "list_processing_jobs", return_value=[{"ProcessingJobName": job_name}])
@@ -176,11 +178,13 @@ def test_execute_with_existing_job_fail(
**self.processing_config_kwargs, config=create_processing_params
)
sagemaker.action_if_job_exists = "fail"
- self.assertRaises(AirflowException, sagemaker.execute, None)
+ with pytest.raises(AirflowException):
+ sagemaker.execute(None)
@mock.patch.object(SageMakerHook, "get_conn")
def test_action_if_job_exists_validation(self, mock_client):
sagemaker = SageMakerProcessingOperator(
**self.processing_config_kwargs, config=create_processing_params
)
- self.assertRaises(AirflowException, sagemaker.__init__, action_if_job_exists="not_fail_or_increment")
+ with pytest.raises(AirflowException):
+ sagemaker.__init__(action_if_job_exists="not_fail_or_increment")
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_training.py b/tests/providers/amazon/aws/operators/test_sagemaker_training.py
index 7424d6f7f8b25..4aeca8c65e077 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_training.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_training.py
@@ -18,6 +18,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators.sagemaker_training import SageMakerTrainingOperator
@@ -72,17 +74,14 @@ def setUp(self):
def test_parse_config_integers(self):
self.sagemaker.parse_config_integers()
- self.assertEqual(
- self.sagemaker.config['ResourceConfig']['InstanceCount'],
- int(self.sagemaker.config['ResourceConfig']['InstanceCount']),
+ assert self.sagemaker.config['ResourceConfig']['InstanceCount'] == int(
+ self.sagemaker.config['ResourceConfig']['InstanceCount']
)
- self.assertEqual(
- self.sagemaker.config['ResourceConfig']['VolumeSizeInGB'],
- int(self.sagemaker.config['ResourceConfig']['VolumeSizeInGB']),
+ assert self.sagemaker.config['ResourceConfig']['VolumeSizeInGB'] == int(
+ self.sagemaker.config['ResourceConfig']['VolumeSizeInGB']
)
- self.assertEqual(
- self.sagemaker.config['StoppingCondition']['MaxRuntimeInSeconds'],
- int(self.sagemaker.config['StoppingCondition']['MaxRuntimeInSeconds']),
+ assert self.sagemaker.config['StoppingCondition']['MaxRuntimeInSeconds'] == int(
+ self.sagemaker.config['StoppingCondition']['MaxRuntimeInSeconds']
)
@mock.patch.object(SageMakerHook, 'get_conn')
@@ -108,7 +107,8 @@ def test_execute_with_failure(self, mock_training, mock_client):
'TrainingJobArn': 'testarn',
'ResponseMetadata': {'HTTPStatusCode': 404},
}
- self.assertRaises(AirflowException, self.sagemaker.execute, None)
+ with pytest.raises(AirflowException):
+ self.sagemaker.execute(None)
# pylint: enable=unused-argument
@@ -143,4 +143,5 @@ def test_execute_with_existing_job_fail(
self.sagemaker.action_if_job_exists = "fail"
mock_create_training_job.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}}
mock_list_training_jobs.return_value = [{"TrainingJobName": job_name}]
- self.assertRaises(AirflowException, self.sagemaker.execute, None)
+ with pytest.raises(AirflowException):
+ self.sagemaker.execute(None)
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
index 739baed701dbe..6ca4dc9334a2a 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
@@ -19,6 +19,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators.sagemaker_transform import SageMakerTransformOperator
@@ -76,12 +78,11 @@ def setUp(self):
def test_parse_config_integers(self):
self.sagemaker.parse_config_integers()
test_config = self.sagemaker.config['Transform']
- self.assertEqual(
- test_config['TransformResources']['InstanceCount'],
- int(test_config['TransformResources']['InstanceCount']),
+ assert test_config['TransformResources']['InstanceCount'] == int(
+ test_config['TransformResources']['InstanceCount']
)
- self.assertEqual(test_config['MaxConcurrentTransforms'], int(test_config['MaxConcurrentTransforms']))
- self.assertEqual(test_config['MaxPayloadInMB'], int(test_config['MaxPayloadInMB']))
+ assert test_config['MaxConcurrentTransforms'] == int(test_config['MaxConcurrentTransforms'])
+ assert test_config['MaxPayloadInMB'] == int(test_config['MaxPayloadInMB'])
@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'create_model')
@@ -105,4 +106,5 @@ def test_execute_with_failure(self, mock_transform, mock_model, mock_client):
'TransformJobArn': 'testarn',
'ResponseMetadata': {'HTTPStatusCode': 404},
}
- self.assertRaises(AirflowException, self.sagemaker.execute, None)
+ with pytest.raises(AirflowException):
+ self.sagemaker.execute(None)
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py b/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py
index 9978944a18eec..3982bcef6a623 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py
@@ -19,6 +19,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators.sagemaker_tuning import SageMakerTuningOperator
@@ -90,33 +92,25 @@ def setUp(self):
def test_parse_config_integers(self):
self.sagemaker.parse_config_integers()
- self.assertEqual(
- self.sagemaker.config['TrainingJobDefinition']['ResourceConfig']['InstanceCount'],
- int(self.sagemaker.config['TrainingJobDefinition']['ResourceConfig']['InstanceCount']),
+ assert self.sagemaker.config['TrainingJobDefinition']['ResourceConfig']['InstanceCount'] == int(
+ self.sagemaker.config['TrainingJobDefinition']['ResourceConfig']['InstanceCount']
)
- self.assertEqual(
- self.sagemaker.config['TrainingJobDefinition']['ResourceConfig']['VolumeSizeInGB'],
- int(self.sagemaker.config['TrainingJobDefinition']['ResourceConfig']['VolumeSizeInGB']),
+ assert self.sagemaker.config['TrainingJobDefinition']['ResourceConfig']['VolumeSizeInGB'] == int(
+ self.sagemaker.config['TrainingJobDefinition']['ResourceConfig']['VolumeSizeInGB']
)
- self.assertEqual(
+ assert self.sagemaker.config['HyperParameterTuningJobConfig']['ResourceLimits'][
+ 'MaxNumberOfTrainingJobs'
+ ] == int(
self.sagemaker.config['HyperParameterTuningJobConfig']['ResourceLimits'][
'MaxNumberOfTrainingJobs'
- ],
- int(
- self.sagemaker.config['HyperParameterTuningJobConfig']['ResourceLimits'][
- 'MaxNumberOfTrainingJobs'
- ]
- ),
+ ]
)
- self.assertEqual(
+ assert self.sagemaker.config['HyperParameterTuningJobConfig']['ResourceLimits'][
+ 'MaxParallelTrainingJobs'
+ ] == int(
self.sagemaker.config['HyperParameterTuningJobConfig']['ResourceLimits'][
'MaxParallelTrainingJobs'
- ],
- int(
- self.sagemaker.config['HyperParameterTuningJobConfig']['ResourceLimits'][
- 'MaxParallelTrainingJobs'
- ]
- ),
+ ]
)
@mock.patch.object(SageMakerHook, 'get_conn')
@@ -132,4 +126,5 @@ def test_execute(self, mock_tuning, mock_client):
@mock.patch.object(SageMakerHook, 'create_tuning_job')
def test_execute_with_failure(self, mock_tuning, mock_client):
mock_tuning.return_value = {'TrainingJobArn': 'testarn', 'ResponseMetadata': {'HTTPStatusCode': 404}}
- self.assertRaises(AirflowException, self.sagemaker.execute, None)
+ with pytest.raises(AirflowException):
+ self.sagemaker.execute(None)
diff --git a/tests/providers/amazon/aws/operators/test_sns.py b/tests/providers/amazon/aws/operators/test_sns.py
index 78eda16e530e7..560d53729da0e 100644
--- a/tests/providers/amazon/aws/operators/test_sns.py
+++ b/tests/providers/amazon/aws/operators/test_sns.py
@@ -43,12 +43,12 @@ def test_init(self):
)
# Then
- self.assertEqual(TASK_ID, operator.task_id)
- self.assertEqual(AWS_CONN_ID, operator.aws_conn_id)
- self.assertEqual(TARGET_ARN, operator.target_arn)
- self.assertEqual(MESSAGE, operator.message)
- self.assertEqual(SUBJECT, operator.subject)
- self.assertEqual(MESSAGE_ATTRIBUTES, operator.message_attributes)
+ assert TASK_ID == operator.task_id
+ assert AWS_CONN_ID == operator.aws_conn_id
+ assert TARGET_ARN == operator.target_arn
+ assert MESSAGE == operator.message
+ assert SUBJECT == operator.subject
+ assert MESSAGE_ATTRIBUTES == operator.message_attributes
@mock.patch('airflow.providers.amazon.aws.operators.sns.AwsSnsHook')
def test_execute(self, mock_hook):
@@ -71,4 +71,4 @@ def test_execute(self, mock_hook):
result = operator.execute(None)
# Then
- self.assertEqual(hook_response, result)
+ assert hook_response == result
diff --git a/tests/providers/amazon/aws/operators/test_sqs.py b/tests/providers/amazon/aws/operators/test_sqs.py
index ef65b67c6d8e1..6da41cb55f296 100644
--- a/tests/providers/amazon/aws/operators/test_sqs.py
+++ b/tests/providers/amazon/aws/operators/test_sqs.py
@@ -51,15 +51,15 @@ def test_execute_success(self):
self.sqs_hook.create_queue('test')
result = self.operator.execute(self.mock_context)
- self.assertTrue('MD5OfMessageBody' in result)
- self.assertTrue('MessageId' in result)
+ assert 'MD5OfMessageBody' in result
+ assert 'MessageId' in result
message = self.sqs_hook.get_conn().receive_message(QueueUrl='test')
- self.assertEqual(len(message['Messages']), 1)
- self.assertEqual(message['Messages'][0]['MessageId'], result['MessageId'])
- self.assertEqual(message['Messages'][0]['Body'], 'hello')
+ assert len(message['Messages']) == 1
+ assert message['Messages'][0]['MessageId'] == result['MessageId']
+ assert message['Messages'][0]['Body'] == 'hello'
context_calls = []
- self.assertTrue(self.mock_context['ti'].method_calls == context_calls, "context call should be same")
+ assert self.mock_context['ti'].method_calls == context_calls, "context call should be same"
diff --git a/tests/providers/amazon/aws/operators/test_step_function_get_execution_output.py b/tests/providers/amazon/aws/operators/test_step_function_get_execution_output.py
index 2aba95bc29d99..8bad0321b359c 100644
--- a/tests/providers/amazon/aws/operators/test_step_function_get_execution_output.py
+++ b/tests/providers/amazon/aws/operators/test_step_function_get_execution_output.py
@@ -45,10 +45,10 @@ def test_init(self):
)
# Then
- self.assertEqual(TASK_ID, operator.task_id)
- self.assertEqual(EXECUTION_ARN, operator.execution_arn)
- self.assertEqual(AWS_CONN_ID, operator.aws_conn_id)
- self.assertEqual(REGION_NAME, operator.region_name)
+ assert TASK_ID == operator.task_id
+ assert EXECUTION_ARN == operator.execution_arn
+ assert AWS_CONN_ID == operator.aws_conn_id
+ assert REGION_NAME == operator.region_name
@mock.patch('airflow.providers.amazon.aws.operators.step_function_get_execution_output.StepFunctionHook')
def test_execute(self, mock_hook):
@@ -66,4 +66,4 @@ def test_execute(self, mock_hook):
result = operator.execute(self.mock_context)
# Then
- self.assertEqual({}, result)
+ assert {} == result
diff --git a/tests/providers/amazon/aws/operators/test_step_function_start_execution.py b/tests/providers/amazon/aws/operators/test_step_function_start_execution.py
index 4ceedbcd37c6d..f71b8f02bb638 100644
--- a/tests/providers/amazon/aws/operators/test_step_function_start_execution.py
+++ b/tests/providers/amazon/aws/operators/test_step_function_start_execution.py
@@ -49,12 +49,12 @@ def test_init(self):
)
# Then
- self.assertEqual(TASK_ID, operator.task_id)
- self.assertEqual(STATE_MACHINE_ARN, operator.state_machine_arn)
- self.assertEqual(NAME, operator.name)
- self.assertEqual(INPUT, operator.input)
- self.assertEqual(AWS_CONN_ID, operator.aws_conn_id)
- self.assertEqual(REGION_NAME, operator.region_name)
+ assert TASK_ID == operator.task_id
+ assert STATE_MACHINE_ARN == operator.state_machine_arn
+ assert NAME == operator.name
+ assert INPUT == operator.input
+ assert AWS_CONN_ID == operator.aws_conn_id
+ assert REGION_NAME == operator.region_name
@mock.patch('airflow.providers.amazon.aws.operators.step_function_start_execution.StepFunctionHook')
def test_execute(self, mock_hook):
@@ -80,4 +80,4 @@ def test_execute(self, mock_hook):
result = operator.execute(self.mock_context)
# Then
- self.assertEqual(hook_response, result)
+ assert hook_response == result
diff --git a/tests/providers/amazon/aws/secrets/test_secrets_manager.py b/tests/providers/amazon/aws/secrets/test_secrets_manager.py
index bb6e272a00896..d45f1aa89cc73 100644
--- a/tests/providers/amazon/aws/secrets/test_secrets_manager.py
+++ b/tests/providers/amazon/aws/secrets/test_secrets_manager.py
@@ -41,7 +41,7 @@ def test_get_conn_uri(self):
secrets_manager_backend.client.put_secret_value(**param)
returned_uri = secrets_manager_backend.get_conn_uri(conn_id="test_postgres")
- self.assertEqual('postgresql://airflow:airflow@host:5432/airflow', returned_uri)
+ assert 'postgresql://airflow:airflow@host:5432/airflow' == returned_uri
@mock_secretsmanager
def test_get_conn_uri_non_existent_key(self):
@@ -58,8 +58,8 @@ def test_get_conn_uri_non_existent_key(self):
secrets_manager_backend = SecretsManagerBackend()
secrets_manager_backend.client.put_secret_value(**param)
- self.assertIsNone(secrets_manager_backend.get_conn_uri(conn_id=conn_id))
- self.assertEqual([], secrets_manager_backend.get_connections(conn_id=conn_id))
+ assert secrets_manager_backend.get_conn_uri(conn_id=conn_id) is None
+ assert [] == secrets_manager_backend.get_connections(conn_id=conn_id)
@mock_secretsmanager
def test_get_variable(self):
@@ -69,7 +69,7 @@ def test_get_variable(self):
secrets_manager_backend.client.put_secret_value(**param)
returned_uri = secrets_manager_backend.get_variable('hello')
- self.assertEqual('world', returned_uri)
+ assert 'world' == returned_uri
@mock_secretsmanager
def test_get_variable_non_existent_key(self):
@@ -82,7 +82,7 @@ def test_get_variable_non_existent_key(self):
secrets_manager_backend = SecretsManagerBackend()
secrets_manager_backend.client.put_secret_value(**param)
- self.assertIsNone(secrets_manager_backend.get_variable("test_mysql"))
+ assert secrets_manager_backend.get_variable("test_mysql") is None
@mock.patch("airflow.providers.amazon.aws.secrets.secrets_manager.SecretsManagerBackend._get_secret")
def test_connection_prefix_none_value(self, mock_get_secret):
@@ -95,7 +95,7 @@ def test_connection_prefix_none_value(self, mock_get_secret):
secrets_manager_backend = SecretsManagerBackend(**kwargs)
- self.assertIsNone(secrets_manager_backend.get_conn_uri("test_mysql"))
+ assert secrets_manager_backend.get_conn_uri("test_mysql") is None
mock_get_secret.assert_not_called()
@mock.patch("airflow.providers.amazon.aws.secrets.secrets_manager.SecretsManagerBackend._get_secret")
@@ -109,7 +109,7 @@ def test_variable_prefix_none_value(self, mock_get_secret):
secrets_manager_backend = SecretsManagerBackend(**kwargs)
- self.assertIsNone(secrets_manager_backend.get_variable("hello"))
+ assert secrets_manager_backend.get_variable("hello") is None
mock_get_secret.assert_not_called()
@mock.patch("airflow.providers.amazon.aws.secrets.secrets_manager.SecretsManagerBackend._get_secret")
@@ -123,5 +123,5 @@ def test_config_prefix_none_value(self, mock_get_secret):
secrets_manager_backend = SecretsManagerBackend(**kwargs)
- self.assertIsNone(secrets_manager_backend.get_config("config"))
+ assert secrets_manager_backend.get_config("config") is None
mock_get_secret.assert_not_called()
diff --git a/tests/providers/amazon/aws/secrets/test_systems_manager.py b/tests/providers/amazon/aws/secrets/test_systems_manager.py
index 01a46693b7a9b..da506b27ed977 100644
--- a/tests/providers/amazon/aws/secrets/test_systems_manager.py
+++ b/tests/providers/amazon/aws/secrets/test_systems_manager.py
@@ -47,7 +47,7 @@ def test_get_conn_uri(self):
ssm_backend.client.put_parameter(**param)
returned_uri = ssm_backend.get_conn_uri(conn_id="test_postgres")
- self.assertEqual('postgresql://airflow:airflow@host:5432/airflow', returned_uri)
+ assert 'postgresql://airflow:airflow@host:5432/airflow' == returned_uri
@mock_ssm
def test_get_conn_uri_non_existent_key(self):
@@ -65,8 +65,8 @@ def test_get_conn_uri_non_existent_key(self):
ssm_backend = SystemsManagerParameterStoreBackend()
ssm_backend.client.put_parameter(**param)
- self.assertIsNone(ssm_backend.get_conn_uri(conn_id=conn_id))
- self.assertEqual([], ssm_backend.get_connections(conn_id=conn_id))
+ assert ssm_backend.get_conn_uri(conn_id=conn_id) is None
+ assert [] == ssm_backend.get_connections(conn_id=conn_id)
@mock_ssm
def test_get_variable(self):
@@ -76,7 +76,7 @@ def test_get_variable(self):
ssm_backend.client.put_parameter(**param)
returned_uri = ssm_backend.get_variable('hello')
- self.assertEqual('world', returned_uri)
+ assert 'world' == returned_uri
@mock_ssm
def test_get_config(self):
@@ -90,7 +90,7 @@ def test_get_config(self):
ssm_backend.client.put_parameter(**param)
returned_uri = ssm_backend.get_config('sql_alchemy_conn')
- self.assertEqual('sqlite:///Users/test_user/airflow.db', returned_uri)
+ assert 'sqlite:///Users/test_user/airflow.db' == returned_uri
@mock_ssm
def test_get_variable_secret_string(self):
@@ -98,7 +98,7 @@ def test_get_variable_secret_string(self):
ssm_backend = SystemsManagerParameterStoreBackend()
ssm_backend.client.put_parameter(**param)
returned_uri = ssm_backend.get_variable('hello')
- self.assertEqual('world', returned_uri)
+ assert 'world' == returned_uri
@mock_ssm
def test_get_variable_non_existent_key(self):
@@ -111,7 +111,7 @@ def test_get_variable_non_existent_key(self):
ssm_backend = SystemsManagerParameterStoreBackend()
ssm_backend.client.put_parameter(**param)
- self.assertIsNone(ssm_backend.get_variable("test_mysql"))
+ assert ssm_backend.get_variable("test_mysql") is None
@conf_vars(
{
@@ -146,7 +146,7 @@ def test_connection_prefix_none_value(self, mock_get_secret):
ssm_backend = SystemsManagerParameterStoreBackend(**kwargs)
- self.assertIsNone(ssm_backend.get_conn_uri("test_mysql"))
+ assert ssm_backend.get_conn_uri("test_mysql") is None
mock_get_secret.assert_not_called()
@mock.patch(
@@ -163,7 +163,7 @@ def test_variable_prefix_none_value(self, mock_get_secret):
ssm_backend = SystemsManagerParameterStoreBackend(**kwargs)
- self.assertIsNone(ssm_backend.get_variable("hello"))
+ assert ssm_backend.get_variable("hello") is None
mock_get_secret.assert_not_called()
@mock.patch(
@@ -180,5 +180,5 @@ def test_config_prefix_none_value(self, mock_get_secret):
ssm_backend = SystemsManagerParameterStoreBackend(**kwargs)
- self.assertIsNone(ssm_backend.get_config("config"))
+ assert ssm_backend.get_config("config") is None
mock_get_secret.assert_not_called()
diff --git a/tests/providers/amazon/aws/sensors/test_athena.py b/tests/providers/amazon/aws/sensors/test_athena.py
index 8f5f2c7173cf7..781f94cdb190a 100644
--- a/tests/providers/amazon/aws/sensors/test_athena.py
+++ b/tests/providers/amazon/aws/sensors/test_athena.py
@@ -19,6 +19,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.athena import AWSAthenaHook
from airflow.providers.amazon.aws.sensors.athena import AthenaSensor
@@ -36,24 +38,24 @@ def setUp(self):
@mock.patch.object(AWSAthenaHook, 'poll_query_status', side_effect=("SUCCEEDED",))
def test_poke_success(self, mock_poll_query_status):
- self.assertTrue(self.sensor.poke(None))
+ assert self.sensor.poke(None)
@mock.patch.object(AWSAthenaHook, 'poll_query_status', side_effect=("RUNNING",))
def test_poke_running(self, mock_poll_query_status):
- self.assertFalse(self.sensor.poke(None))
+ assert not self.sensor.poke(None)
@mock.patch.object(AWSAthenaHook, 'poll_query_status', side_effect=("QUEUED",))
def test_poke_queued(self, mock_poll_query_status):
- self.assertFalse(self.sensor.poke(None))
+ assert not self.sensor.poke(None)
@mock.patch.object(AWSAthenaHook, 'poll_query_status', side_effect=("FAILED",))
def test_poke_failed(self, mock_poll_query_status):
- with self.assertRaises(AirflowException) as context:
+ with pytest.raises(AirflowException) as ctx:
self.sensor.poke(None)
- self.assertIn('Athena sensor failed', str(context.exception))
+ assert 'Athena sensor failed' in str(ctx.value)
@mock.patch.object(AWSAthenaHook, 'poll_query_status', side_effect=("CANCELLED",))
def test_poke_cancelled(self, mock_poll_query_status):
- with self.assertRaises(AirflowException) as context:
+ with pytest.raises(AirflowException) as ctx:
self.sensor.poke(None)
- self.assertIn('Athena sensor failed', str(context.exception))
+ assert 'Athena sensor failed' in str(ctx.value)
diff --git a/tests/providers/amazon/aws/sensors/test_cloud_formation.py b/tests/providers/amazon/aws/sensors/test_cloud_formation.py
index 382933b60fa08..a2c15dac448d3 100644
--- a/tests/providers/amazon/aws/sensors/test_cloud_formation.py
+++ b/tests/providers/amazon/aws/sensors/test_cloud_formation.py
@@ -19,6 +19,7 @@
from unittest.mock import MagicMock, patch
import boto3
+import pytest
from airflow.providers.amazon.aws.sensors.cloud_formation import (
CloudFormationCreateStackSensor,
@@ -53,7 +54,7 @@ def test_poke(self):
stack_name = 'foobar'
self.client.create_stack(StackName=stack_name, TemplateBody='{"Resources": {}}')
op = CloudFormationCreateStackSensor(task_id='task', stack_name='foobar')
- self.assertTrue(op.poke({}))
+ assert op.poke({})
def test_poke_false(self):
with patch('boto3.session.Session', self.boto3_session_mock):
@@ -61,18 +62,18 @@ def test_poke_false(self):
'Stacks': [{'StackStatus': 'CREATE_IN_PROGRESS'}]
}
op = CloudFormationCreateStackSensor(task_id='task', stack_name='foo')
- self.assertFalse(op.poke({}))
+ assert not op.poke({})
def test_poke_stack_in_unsuccessful_state(self):
with patch('boto3.session.Session', self.boto3_session_mock):
self.cloudformation_client_mock.describe_stacks.return_value = {
'Stacks': [{'StackStatus': 'bar'}]
}
- with self.assertRaises(ValueError) as error:
+ with pytest.raises(ValueError) as ctx:
op = CloudFormationCreateStackSensor(task_id='task', stack_name='foo')
op.poke({})
- self.assertEqual('Stack foo in bad state: bar', str(error.exception))
+ assert 'Stack foo in bad state: bar' == str(ctx.value)
@unittest.skipIf(
@@ -98,7 +99,7 @@ def test_poke(self):
self.client.create_stack(StackName=stack_name, TemplateBody='{"Resources": {}}')
self.client.delete_stack(StackName=stack_name)
op = CloudFormationDeleteStackSensor(task_id='task', stack_name=stack_name)
- self.assertTrue(op.poke({}))
+ assert op.poke({})
def test_poke_false(self):
with patch('boto3.session.Session', self.boto3_session_mock):
@@ -106,20 +107,20 @@ def test_poke_false(self):
'Stacks': [{'StackStatus': 'DELETE_IN_PROGRESS'}]
}
op = CloudFormationDeleteStackSensor(task_id='task', stack_name='foo')
- self.assertFalse(op.poke({}))
+ assert not op.poke({})
def test_poke_stack_in_unsuccessful_state(self):
with patch('boto3.session.Session', self.boto3_session_mock):
self.cloudformation_client_mock.describe_stacks.return_value = {
'Stacks': [{'StackStatus': 'bar'}]
}
- with self.assertRaises(ValueError) as error:
+ with pytest.raises(ValueError) as ctx:
op = CloudFormationDeleteStackSensor(task_id='task', stack_name='foo')
op.poke({})
- self.assertEqual('Stack foo in bad state: bar', str(error.exception))
+ assert 'Stack foo in bad state: bar' == str(ctx.value)
@mock_cloudformation
def test_poke_stack_does_not_exist(self):
op = CloudFormationDeleteStackSensor(task_id='task', stack_name='foo')
- self.assertTrue(op.poke({}))
+ assert op.poke({})
diff --git a/tests/providers/amazon/aws/sensors/test_ec2_instance_state.py b/tests/providers/amazon/aws/sensors/test_ec2_instance_state.py
index 03f313939d48a..e715da291f30d 100644
--- a/tests/providers/amazon/aws/sensors/test_ec2_instance_state.py
+++ b/tests/providers/amazon/aws/sensors/test_ec2_instance_state.py
@@ -19,6 +19,7 @@
import unittest
+import pytest
from moto import mock_ec2
from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
@@ -34,22 +35,22 @@ def test_init(self):
aws_conn_id="aws_conn_test",
region_name="region-test",
)
- self.assertEqual(ec2_operator.task_id, "task_test")
- self.assertEqual(ec2_operator.target_state, "stopped")
- self.assertEqual(ec2_operator.instance_id, "i-123abc")
- self.assertEqual(ec2_operator.aws_conn_id, "aws_conn_test")
- self.assertEqual(ec2_operator.region_name, "region-test")
+ assert ec2_operator.task_id == "task_test"
+ assert ec2_operator.target_state == "stopped"
+ assert ec2_operator.instance_id == "i-123abc"
+ assert ec2_operator.aws_conn_id == "aws_conn_test"
+ assert ec2_operator.region_name == "region-test"
def test_init_invalid_target_state(self):
invalid_target_state = "target_state_test"
- with self.assertRaises(ValueError) as cm:
+ with pytest.raises(ValueError) as ctx:
EC2InstanceStateSensor(
task_id="task_test",
target_state=invalid_target_state,
instance_id="i-123abc",
)
msg = f"Invalid target_state: {invalid_target_state}"
- self.assertEqual(str(cm.exception), msg)
+ assert str(ctx.value) == msg
@mock_ec2
def test_running(self):
@@ -70,11 +71,11 @@ def test_running(self):
instance_id=instance_id,
)
# assert instance state is not running
- self.assertFalse(start_sensor.poke(None))
+ assert not start_sensor.poke(None)
# start instance
ec2_hook.get_instance(instance_id=instance_id).start()
# assert instance state is running
- self.assertTrue(start_sensor.poke(None))
+ assert start_sensor.poke(None)
@mock_ec2
def test_stopped(self):
@@ -95,11 +96,11 @@ def test_stopped(self):
instance_id=instance_id,
)
# assert instance state is not stopped
- self.assertFalse(stop_sensor.poke(None))
+ assert not stop_sensor.poke(None)
# stop instance
ec2_hook.get_instance(instance_id=instance_id).stop()
# assert instance state is stopped
- self.assertTrue(stop_sensor.poke(None))
+ assert stop_sensor.poke(None)
@mock_ec2
def test_terminated(self):
@@ -120,8 +121,8 @@ def test_terminated(self):
instance_id=instance_id,
)
# assert instance state is not terminated
- self.assertFalse(stop_sensor.poke(None))
+ assert not stop_sensor.poke(None)
# stop instance
ec2_hook.get_instance(instance_id=instance_id).terminate()
# assert instance state is terminated
- self.assertTrue(stop_sensor.poke(None))
+ assert stop_sensor.poke(None)
diff --git a/tests/providers/amazon/aws/sensors/test_emr_base.py b/tests/providers/amazon/aws/sensors/test_emr_base.py
index 2f2005a4bc6dc..3ccaf28a9f2fe 100644
--- a/tests/providers/amazon/aws/sensors/test_emr_base.py
+++ b/tests/providers/amazon/aws/sensors/test_emr_base.py
@@ -18,6 +18,8 @@
import unittest
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.sensors.emr_base import EmrBaseSensor
@@ -79,7 +81,7 @@ def test_poke_returns_false_when_state_is_not_in_target_states(self):
'ResponseMetadata': {'HTTPStatusCode': GOOD_HTTP_STATUS},
}
- self.assertEqual(operator.poke(None), False)
+ assert operator.poke(None) is False
def test_poke_returns_false_when_http_response_is_bad(self):
operator = EmrBaseSensorSubclass(
@@ -91,7 +93,7 @@ def test_poke_returns_false_when_http_response_is_bad(self):
'ResponseMetadata': {'HTTPStatusCode': BAD_HTTP_STATUS},
}
- self.assertEqual(operator.poke(None), False)
+ assert operator.poke(None) is False
def test_poke_raises_error_when_state_is_in_failed_states(self):
operator = EmrBaseSensorSubclass(
@@ -103,9 +105,9 @@ def test_poke_raises_error_when_state_is_in_failed_states(self):
'ResponseMetadata': {'HTTPStatusCode': GOOD_HTTP_STATUS},
}
- with self.assertRaises(AirflowException) as context:
+ with pytest.raises(AirflowException) as ctx:
operator.poke(None)
- self.assertIn('EMR job failed', str(context.exception))
- self.assertIn(EXPECTED_CODE, str(context.exception))
- self.assertNotIn(EMPTY_CODE, str(context.exception))
+ assert 'EMR job failed' in str(ctx.value)
+ assert EXPECTED_CODE in str(ctx.value)
+ assert EMPTY_CODE not in str(ctx.value)
diff --git a/tests/providers/amazon/aws/sensors/test_emr_job_flow.py b/tests/providers/amazon/aws/sensors/test_emr_job_flow.py
index a98f8f7ea5ab8..8aba20b0d71cb 100644
--- a/tests/providers/amazon/aws/sensors/test_emr_job_flow.py
+++ b/tests/providers/amazon/aws/sensors/test_emr_job_flow.py
@@ -20,6 +20,7 @@
import unittest
from unittest.mock import MagicMock, patch
+import pytest
from dateutil.tz import tzlocal
from airflow.exceptions import AirflowException
@@ -211,7 +212,7 @@ def test_execute_calls_with_the_job_flow_id_until_it_reaches_a_target_state(self
operator.execute(None)
# make sure we called twice
- self.assertEqual(self.mock_emr_client.describe_cluster.call_count, 3)
+ assert self.mock_emr_client.describe_cluster.call_count == 3
# make sure it was called with the job_flow_id
calls = [unittest.mock.call(ClusterId='j-8989898989')]
@@ -227,11 +228,11 @@ def test_execute_calls_with_the_job_flow_id_until_it_reaches_failed_state_with_e
task_id='test_task', poke_interval=0, job_flow_id='j-8989898989', aws_conn_id='aws_default'
)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
operator.execute(None)
# make sure we called twice
- self.assertEqual(self.mock_emr_client.describe_cluster.call_count, 2)
+ assert self.mock_emr_client.describe_cluster.call_count == 2
# make sure it was called with the job_flow_id
self.mock_emr_client.describe_cluster.assert_called_once_with(ClusterId='j-8989898989')
@@ -257,7 +258,7 @@ def test_different_target_states(self):
operator.execute(None)
# make sure we called twice
- self.assertEqual(self.mock_emr_client.describe_cluster.call_count, 3)
+ assert self.mock_emr_client.describe_cluster.call_count == 3
# make sure it was called with the job_flow_id
calls = [unittest.mock.call(ClusterId='j-8989898989')]
diff --git a/tests/providers/amazon/aws/sensors/test_emr_step.py b/tests/providers/amazon/aws/sensors/test_emr_step.py
index 512ec5b972486..5319ef18d4fa9 100644
--- a/tests/providers/amazon/aws/sensors/test_emr_step.py
+++ b/tests/providers/amazon/aws/sensors/test_emr_step.py
@@ -20,6 +20,7 @@
from datetime import datetime
from unittest.mock import MagicMock, patch
+import pytest
from dateutil.tz import tzlocal
from airflow.exceptions import AirflowException
@@ -166,7 +167,7 @@ def test_step_completed(self):
with patch('boto3.session.Session', self.boto3_session_mock):
self.sensor.execute(None)
- self.assertEqual(self.emr_client_mock.describe_step.call_count, 2)
+ assert self.emr_client_mock.describe_step.call_count == 2
calls = [
unittest.mock.call(ClusterId='j-8989898989', StepId='s-VK57YR1Z9Z5N'),
unittest.mock.call(ClusterId='j-8989898989', StepId='s-VK57YR1Z9Z5N'),
@@ -180,7 +181,8 @@ def test_step_cancelled(self):
]
with patch('boto3.session.Session', self.boto3_session_mock):
- self.assertRaises(AirflowException, self.sensor.execute, None)
+ with pytest.raises(AirflowException):
+ self.sensor.execute(None)
def test_step_failed(self):
self.emr_client_mock.describe_step.side_effect = [
@@ -189,7 +191,8 @@ def test_step_failed(self):
]
with patch('boto3.session.Session', self.boto3_session_mock):
- self.assertRaises(AirflowException, self.sensor.execute, None)
+ with pytest.raises(AirflowException):
+ self.sensor.execute(None)
def test_step_interrupted(self):
self.emr_client_mock.describe_step.side_effect = [
@@ -198,4 +201,5 @@ def test_step_interrupted(self):
]
with patch('boto3.session.Session', self.boto3_session_mock):
- self.assertRaises(AirflowException, self.sensor.execute, None)
+ with pytest.raises(AirflowException):
+ self.sensor.execute(None)
diff --git a/tests/providers/amazon/aws/sensors/test_glacier.py b/tests/providers/amazon/aws/sensors/test_glacier.py
index d954c6e7ed599..5e4670a8de136 100644
--- a/tests/providers/amazon/aws/sensors/test_glacier.py
+++ b/tests/providers/amazon/aws/sensors/test_glacier.py
@@ -19,6 +19,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow import AirflowException
from airflow.providers.amazon.aws.sensors.glacier import GlacierJobOperationSensor, JobStatus
@@ -41,28 +43,28 @@ def setUp(self):
side_effect=[{"Action": "", "StatusCode": JobStatus.SUCCEEDED.value}],
)
def test_poke_succeeded(self, _):
- self.assertTrue(self.op.poke(None))
+ assert self.op.poke(None)
@mock.patch(
"airflow.providers.amazon.aws.sensors.glacier.GlacierHook.describe_job",
side_effect=[{"Action": "", "StatusCode": JobStatus.IN_PROGRESS.value}],
)
def test_poke_in_progress(self, _):
- self.assertFalse(self.op.poke(None))
+ assert not self.op.poke(None)
@mock.patch(
"airflow.providers.amazon.aws.sensors.glacier.GlacierHook.describe_job",
side_effect=[{"Action": "", "StatusCode": ""}],
)
def test_poke_fail(self, _):
- with self.assertRaises(AirflowException) as context:
+ with pytest.raises(AirflowException) as ctx:
self.op.poke(None)
- self.assertIn('Sensor failed', str(context.exception))
+ assert 'Sensor failed' in str(ctx.value)
class TestSensorJobDescription(unittest.TestCase):
def test_job_status_success(self):
- self.assertEqual(JobStatus.SUCCEEDED.value, SUCCEEDED)
+ assert JobStatus.SUCCEEDED.value == SUCCEEDED
def test_job_status_in_progress(self):
- self.assertEqual(JobStatus.IN_PROGRESS.value, IN_PROGRESS)
+ assert JobStatus.IN_PROGRESS.value == IN_PROGRESS
diff --git a/tests/providers/amazon/aws/sensors/test_glue.py b/tests/providers/amazon/aws/sensors/test_glue.py
index ff503b1467970..4a0c401b2a326 100644
--- a/tests/providers/amazon/aws/sensors/test_glue.py
+++ b/tests/providers/amazon/aws/sensors/test_glue.py
@@ -40,7 +40,7 @@ def test_poke(self, mock_get_job_state, mock_conn):
timeout=5,
aws_conn_id='aws_default',
)
- self.assertTrue(op.poke(None))
+ assert op.poke(None)
@mock.patch.object(AwsGlueJobHook, 'get_conn')
@mock.patch.object(AwsGlueJobHook, 'get_job_state')
@@ -55,7 +55,7 @@ def test_poke_false(self, mock_get_job_state, mock_conn):
timeout=5,
aws_conn_id='aws_default',
)
- self.assertFalse(op.poke(None))
+ assert not op.poke(None)
if __name__ == '__main__':
diff --git a/tests/providers/amazon/aws/sensors/test_glue_catalog_partition.py b/tests/providers/amazon/aws/sensors/test_glue_catalog_partition.py
index a4c264bbe6359..8f41cc5143524 100644
--- a/tests/providers/amazon/aws/sensors/test_glue_catalog_partition.py
+++ b/tests/providers/amazon/aws/sensors/test_glue_catalog_partition.py
@@ -38,14 +38,14 @@ class TestAwsGlueCatalogPartitionSensor(unittest.TestCase):
def test_poke(self, mock_check_for_partition):
mock_check_for_partition.return_value = True
op = AwsGlueCatalogPartitionSensor(task_id=self.task_id, table_name='tbl')
- self.assertTrue(op.poke(None))
+ assert op.poke(None)
@mock_glue
@mock.patch.object(AwsGlueCatalogHook, 'check_for_partition')
def test_poke_false(self, mock_check_for_partition):
mock_check_for_partition.return_value = False
op = AwsGlueCatalogPartitionSensor(task_id=self.task_id, table_name='tbl')
- self.assertFalse(op.poke(None))
+ assert not op.poke(None)
@mock_glue
@mock.patch.object(AwsGlueCatalogHook, 'check_for_partition')
@@ -54,8 +54,8 @@ def test_poke_default_args(self, mock_check_for_partition):
op = AwsGlueCatalogPartitionSensor(task_id=self.task_id, table_name=table_name)
op.poke(None)
- self.assertEqual(op.hook.region_name, None)
- self.assertEqual(op.hook.aws_conn_id, 'aws_default')
+ assert op.hook.region_name is None
+ assert op.hook.aws_conn_id == 'aws_default'
mock_check_for_partition.assert_called_once_with('default', table_name, "ds='{{ ds }}'")
@mock_glue
@@ -80,10 +80,10 @@ def test_poke_nondefault_args(self, mock_check_for_partition):
)
op.poke(None)
- self.assertEqual(op.hook.region_name, region_name)
- self.assertEqual(op.hook.aws_conn_id, aws_conn_id)
- self.assertEqual(op.poke_interval, poke_interval)
- self.assertEqual(op.timeout, timeout)
+ assert op.hook.region_name == region_name
+ assert op.hook.aws_conn_id == aws_conn_id
+ assert op.poke_interval == poke_interval
+ assert op.timeout == timeout
mock_check_for_partition.assert_called_once_with(database_name, table_name, expression)
@mock_glue
diff --git a/tests/providers/amazon/aws/sensors/test_redshift.py b/tests/providers/amazon/aws/sensors/test_redshift.py
index ddd3e3179f6b9..ec7ae66ab5317 100644
--- a/tests/providers/amazon/aws/sensors/test_redshift.py
+++ b/tests/providers/amazon/aws/sensors/test_redshift.py
@@ -54,7 +54,7 @@ def test_poke(self):
cluster_identifier='test_cluster',
target_status='available',
)
- self.assertTrue(op.poke(None))
+ assert op.poke(None)
@unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
@mock_redshift
@@ -69,7 +69,7 @@ def test_poke_false(self):
target_status='available',
)
- self.assertFalse(op.poke(None))
+ assert not op.poke(None)
@unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
@mock_redshift
@@ -84,4 +84,4 @@ def test_poke_cluster_not_found(self):
target_status='cluster_not_found',
)
- self.assertTrue(op.poke(None))
+ assert op.poke(None)
diff --git a/tests/providers/amazon/aws/sensors/test_s3_key.py b/tests/providers/amazon/aws/sensors/test_s3_key.py
index 8ef07b3b31ca1..a4c57f8c39b71 100644
--- a/tests/providers/amazon/aws/sensors/test_s3_key.py
+++ b/tests/providers/amazon/aws/sensors/test_s3_key.py
@@ -20,6 +20,7 @@
from datetime import datetime
from unittest import mock
+import pytest
from parameterized import parameterized
from airflow.exceptions import AirflowException
@@ -37,7 +38,7 @@ def test_bucket_name_none_and_bucket_key_as_relative_path(self):
:return:
"""
op = S3KeySensor(task_id='s3_key_sensor', bucket_key="file_in_bucket")
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
op.poke(None)
def test_bucket_name_provided_and_bucket_key_is_s3_url(self):
@@ -49,7 +50,7 @@ def test_bucket_name_provided_and_bucket_key_is_s3_url(self):
op = S3KeySensor(
task_id='s3_key_sensor', bucket_key="s3://test_bucket/file", bucket_name='test_bucket'
)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
op.poke(None)
@parameterized.expand(
@@ -70,8 +71,8 @@ def test_parse_bucket_key(self, key, bucket, parsed_key, parsed_bucket, mock_hoo
op.poke(None)
- self.assertEqual(op.bucket_key, parsed_key)
- self.assertEqual(op.bucket_name, parsed_bucket)
+ assert op.bucket_key == parsed_key
+ assert op.bucket_name == parsed_bucket
@mock.patch('airflow.providers.amazon.aws.sensors.s3_key.S3Hook')
def test_parse_bucket_key_from_jinja(self, mock_hook):
@@ -95,8 +96,8 @@ def test_parse_bucket_key_from_jinja(self, mock_hook):
op.poke(None)
- self.assertEqual(op.bucket_key, "key")
- self.assertEqual(op.bucket_name, "bucket")
+ assert op.bucket_key == "key"
+ assert op.bucket_name == "bucket"
@mock.patch('airflow.providers.amazon.aws.sensors.s3_key.S3Hook')
def test_poke(self, mock_hook):
@@ -104,11 +105,11 @@ def test_poke(self, mock_hook):
mock_check_for_key = mock_hook.return_value.check_for_key
mock_check_for_key.return_value = False
- self.assertFalse(op.poke(None))
+ assert not op.poke(None)
mock_check_for_key.assert_called_once_with(op.bucket_key, op.bucket_name)
mock_hook.return_value.check_for_key.return_value = True
- self.assertTrue(op.poke(None))
+ assert op.poke(None)
@mock.patch('airflow.providers.amazon.aws.sensors.s3_key.S3Hook')
def test_poke_wildcard(self, mock_hook):
@@ -116,25 +117,25 @@ def test_poke_wildcard(self, mock_hook):
mock_check_for_wildcard_key = mock_hook.return_value.check_for_wildcard_key
mock_check_for_wildcard_key.return_value = False
- self.assertFalse(op.poke(None))
+ assert not op.poke(None)
mock_check_for_wildcard_key.assert_called_once_with(op.bucket_key, op.bucket_name)
mock_check_for_wildcard_key.return_value = True
- self.assertTrue(op.poke(None))
+ assert op.poke(None)
class TestS3KeySizeSensor(unittest.TestCase):
@mock.patch('airflow.providers.amazon.aws.sensors.s3_key.S3Hook.check_for_key', return_value=False)
def test_poke_check_for_key_false(self, mock_check_for_key):
op = S3KeySizeSensor(task_id='s3_key_sensor', bucket_key='s3://test_bucket/file')
- self.assertFalse(op.poke(None))
+ assert not op.poke(None)
mock_check_for_key.assert_called_once_with(op.bucket_key, op.bucket_name)
@mock.patch('airflow.providers.amazon.aws.sensors.s3_key.S3KeySizeSensor.get_files', return_value=[])
@mock.patch('airflow.providers.amazon.aws.sensors.s3_key.S3Hook.check_for_key', return_value=True)
def test_poke_get_files_false(self, mock_check_for_key, mock_get_files):
op = S3KeySizeSensor(task_id='s3_key_sensor', bucket_key='s3://test_bucket/file')
- self.assertFalse(op.poke(None))
+ assert not op.poke(None)
mock_check_for_key.assert_called_once_with(op.bucket_key, op.bucket_name)
mock_get_files.assert_called_once_with(s3_hook=op.get_hook())
@@ -160,5 +161,5 @@ def test_poke(self, paginate_return_value, poke_return_value, mock_hook):
mock_conn.return_value.get_paginator.return_value = mock_paginator
mock_hook.return_value.get_conn = mock_conn
mock_paginator.paginate.return_value = [paginate_return_value]
- self.assertIs(op.poke(None), poke_return_value)
+ assert op.poke(None) is poke_return_value
mock_check_for_key.assert_called_once_with(op.bucket_key, op.bucket_name)
diff --git a/tests/providers/amazon/aws/sensors/test_s3_keys_unchanged.py b/tests/providers/amazon/aws/sensors/test_s3_keys_unchanged.py
index 91bdfab08cc91..9e1bbcc4d53dc 100644
--- a/tests/providers/amazon/aws/sensors/test_s3_keys_unchanged.py
+++ b/tests/providers/amazon/aws/sensors/test_s3_keys_unchanged.py
@@ -19,6 +19,7 @@
from datetime import datetime
from unittest import TestCase, mock
+import pytest
from freezegun import freeze_time
from parameterized import parameterized
@@ -51,7 +52,7 @@ def setUp(self):
)
def test_reschedule_mode_not_allowed(self):
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
S3KeysUnchangedSensor(
task_id='sensor_2',
bucket_name='test-bucket',
@@ -77,7 +78,7 @@ def test_render_template_fields(self):
def test_files_deleted_between_pokes_throw_error(self):
self.sensor.allow_delete = False
self.sensor.is_keys_unchanged({'a', 'b'})
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.sensor.is_keys_unchanged({'a'})
@parameterized.expand(
@@ -94,17 +95,17 @@ def test_files_deleted_between_pokes_throw_error(self):
)
@freeze_time(DEFAULT_DATE, auto_tick_seconds=10)
def test_key_changes(self, current_objects, expected_returns, inactivity_periods):
- self.assertEqual(self.sensor.is_keys_unchanged(current_objects[0]), expected_returns[0])
- self.assertEqual(self.sensor.inactivity_seconds, inactivity_periods[0])
- self.assertEqual(self.sensor.is_keys_unchanged(current_objects[1]), expected_returns[1])
- self.assertEqual(self.sensor.inactivity_seconds, inactivity_periods[1])
- self.assertEqual(self.sensor.is_keys_unchanged(current_objects[2]), expected_returns[2])
- self.assertEqual(self.sensor.inactivity_seconds, inactivity_periods[2])
+ assert self.sensor.is_keys_unchanged(current_objects[0]) == expected_returns[0]
+ assert self.sensor.inactivity_seconds == inactivity_periods[0]
+ assert self.sensor.is_keys_unchanged(current_objects[1]) == expected_returns[1]
+ assert self.sensor.inactivity_seconds == inactivity_periods[1]
+ assert self.sensor.is_keys_unchanged(current_objects[2]) == expected_returns[2]
+ assert self.sensor.inactivity_seconds == inactivity_periods[2]
@freeze_time(DEFAULT_DATE, auto_tick_seconds=10)
@mock.patch('airflow.providers.amazon.aws.sensors.s3_keys_unchanged.S3Hook')
def test_poke_succeeds_on_upload_complete(self, mock_hook):
mock_hook.return_value.list_keys.return_value = {'a'}
- self.assertFalse(self.sensor.poke(dict()))
- self.assertFalse(self.sensor.poke(dict()))
- self.assertTrue(self.sensor.poke(dict()))
+ assert not self.sensor.poke(dict())
+ assert not self.sensor.poke(dict())
+ assert self.sensor.poke(dict())
diff --git a/tests/providers/amazon/aws/sensors/test_s3_prefix.py b/tests/providers/amazon/aws/sensors/test_s3_prefix.py
index a06d6f0186259..41e61671c7427 100644
--- a/tests/providers/amazon/aws/sensors/test_s3_prefix.py
+++ b/tests/providers/amazon/aws/sensors/test_s3_prefix.py
@@ -28,10 +28,10 @@ def test_poke(self, mock_hook):
op = S3PrefixSensor(task_id='s3_prefix', bucket_name='bucket', prefix='prefix')
mock_hook.return_value.check_for_prefix.return_value = False
- self.assertFalse(op.poke(None))
+ assert not op.poke(None)
mock_hook.return_value.check_for_prefix.assert_called_once_with(
prefix='prefix', delimiter='/', bucket_name='bucket'
)
mock_hook.return_value.check_for_prefix.return_value = True
- self.assertTrue(op.poke(None))
+ assert op.poke(None)
diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_base.py b/tests/providers/amazon/aws/sensors/test_sagemaker_base.py
index a52f7e53827e5..2acaef9e732e3 100644
--- a/tests/providers/amazon/aws/sensors/test_sagemaker_base.py
+++ b/tests/providers/amazon/aws/sensors/test_sagemaker_base.py
@@ -18,6 +18,8 @@
import unittest
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor
@@ -57,7 +59,7 @@ def state_from_response(self, response):
sensor = SageMakerBaseSensorSubclass(task_id='test_task', poke_interval=2, aws_conn_id='aws_test')
- self.assertEqual(sensor.poke(None), False)
+ assert sensor.poke(None) is False
def test_poke_with_not_implemented_method(self):
class SageMakerBaseSensorSubclass(SageMakerBaseSensor):
@@ -69,7 +71,8 @@ def failed_states(self):
sensor = SageMakerBaseSensorSubclass(task_id='test_task', poke_interval=2, aws_conn_id='aws_test')
- self.assertRaises(NotImplementedError, sensor.poke, None)
+ with pytest.raises(NotImplementedError):
+ sensor.poke(None)
def test_poke_with_bad_response(self):
class SageMakerBaseSensorSubclass(SageMakerBaseSensor):
@@ -87,7 +90,7 @@ def state_from_response(self, response):
sensor = SageMakerBaseSensorSubclass(task_id='test_task', poke_interval=2, aws_conn_id='aws_test')
- self.assertEqual(sensor.poke(None), False)
+ assert sensor.poke(None) is False
def test_poke_with_job_failure(self):
class SageMakerBaseSensorSubclass(SageMakerBaseSensor):
@@ -105,4 +108,5 @@ def state_from_response(self, response):
sensor = SageMakerBaseSensorSubclass(task_id='test_task', poke_interval=2, aws_conn_id='aws_test')
- self.assertRaises(AirflowException, sensor.poke, None)
+ with pytest.raises(AirflowException):
+ sensor.poke(None)
diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_endpoint.py b/tests/providers/amazon/aws/sensors/test_sagemaker_endpoint.py
index ca4090b12f2dd..410d7ec1c4ba7 100644
--- a/tests/providers/amazon/aws/sensors/test_sagemaker_endpoint.py
+++ b/tests/providers/amazon/aws/sensors/test_sagemaker_endpoint.py
@@ -19,6 +19,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.sensors.sagemaker_endpoint import SageMakerEndpointSensor
@@ -60,7 +62,8 @@ def test_sensor_with_failure(self, mock_describe, mock_get_conn):
sensor = SageMakerEndpointSensor(
task_id='test_task', poke_interval=1, aws_conn_id='aws_test', endpoint_name='test_job_name'
)
- self.assertRaises(AirflowException, sensor.execute, None)
+ with pytest.raises(AirflowException):
+ sensor.execute(None)
mock_describe.assert_called_once_with('test_job_name')
@mock.patch.object(SageMakerHook, 'get_conn')
@@ -81,7 +84,7 @@ def test_sensor(self, mock_describe, hook_init, mock_get_conn):
sensor.execute(None)
# make sure we called 3 times(terminated when its completed)
- self.assertEqual(mock_describe.call_count, 3)
+ assert mock_describe.call_count == 3
# make sure the hook was initialized with the specific params
calls = [mock.call(aws_conn_id='aws_test')]
diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_training.py b/tests/providers/amazon/aws/sensors/test_sagemaker_training.py
index c6d5b782a68ed..09d98f762ed08 100644
--- a/tests/providers/amazon/aws/sensors/test_sagemaker_training.py
+++ b/tests/providers/amazon/aws/sensors/test_sagemaker_training.py
@@ -20,6 +20,8 @@
from datetime import datetime
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.hooks.sagemaker import LogState, SageMakerHook
@@ -60,7 +62,8 @@ def test_sensor_with_failure(self, mock_describe_job, hook_init, mock_client):
job_name='test_job_name',
print_log=False,
)
- self.assertRaises(AirflowException, sensor.execute, None)
+ with pytest.raises(AirflowException):
+ sensor.execute(None)
mock_describe_job.assert_called_once_with('test_job_name')
@mock.patch.object(SageMakerHook, 'get_conn')
@@ -85,7 +88,7 @@ def test_sensor(self, mock_describe_job, hook_init, mock_client):
sensor.execute(None)
# make sure we called 3 times(terminated when its completed)
- self.assertEqual(mock_describe_job.call_count, 3)
+ assert mock_describe_job.call_count == 3
# make sure the hook was initialized with the specific params
calls = [mock.call(aws_conn_id='aws_test')]
@@ -117,8 +120,8 @@ def test_sensor_with_log(
sensor.execute(None)
- self.assertEqual(mock_describe_job_with_log.call_count, 3)
- self.assertEqual(mock_describe_job.call_count, 1)
+ assert mock_describe_job_with_log.call_count == 3
+ assert mock_describe_job.call_count == 1
calls = [mock.call(aws_conn_id='aws_test')]
hook_init.assert_has_calls(calls)
diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_transform.py b/tests/providers/amazon/aws/sensors/test_sagemaker_transform.py
index b0463373c10ac..a3e23d8a505c7 100644
--- a/tests/providers/amazon/aws/sensors/test_sagemaker_transform.py
+++ b/tests/providers/amazon/aws/sensors/test_sagemaker_transform.py
@@ -19,6 +19,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.sensors.sagemaker_transform import SageMakerTransformSensor
@@ -58,7 +60,8 @@ def test_sensor_with_failure(self, mock_describe_job, mock_client):
sensor = SageMakerTransformSensor(
task_id='test_task', poke_interval=2, aws_conn_id='aws_test', job_name='test_job_name'
)
- self.assertRaises(AirflowException, sensor.execute, None)
+ with pytest.raises(AirflowException):
+ sensor.execute(None)
mock_describe_job.assert_called_once_with('test_job_name')
@mock.patch.object(SageMakerHook, 'get_conn')
@@ -79,7 +82,7 @@ def test_sensor(self, mock_describe_job, hook_init, mock_client):
sensor.execute(None)
# make sure we called 3 times(terminated when its completed)
- self.assertEqual(mock_describe_job.call_count, 3)
+ assert mock_describe_job.call_count == 3
# make sure the hook was initialized with the specific params
calls = [mock.call(aws_conn_id='aws_test')]
diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_tuning.py b/tests/providers/amazon/aws/sensors/test_sagemaker_tuning.py
index 32b2553f1f5c6..9b79b60028950 100644
--- a/tests/providers/amazon/aws/sensors/test_sagemaker_tuning.py
+++ b/tests/providers/amazon/aws/sensors/test_sagemaker_tuning.py
@@ -19,6 +19,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.sensors.sagemaker_tuning import SageMakerTuningSensor
@@ -61,7 +63,8 @@ def test_sensor_with_failure(self, mock_describe_job, mock_client):
sensor = SageMakerTuningSensor(
task_id='test_task', poke_interval=2, aws_conn_id='aws_test', job_name='test_job_name'
)
- self.assertRaises(AirflowException, sensor.execute, None)
+ with pytest.raises(AirflowException):
+ sensor.execute(None)
mock_describe_job.assert_called_once_with('test_job_name')
@mock.patch.object(SageMakerHook, 'get_conn')
@@ -82,7 +85,7 @@ def test_sensor(self, mock_describe_job, hook_init, mock_client):
sensor.execute(None)
# make sure we called 3 times(terminated when its completed)
- self.assertEqual(mock_describe_job.call_count, 3)
+ assert mock_describe_job.call_count == 3
# make sure the hook was initialized with the specific params
calls = [mock.call(aws_conn_id='aws_test')]
diff --git a/tests/providers/amazon/aws/sensors/test_sqs.py b/tests/providers/amazon/aws/sensors/test_sqs.py
index cbd6639da3c5a..90349a321c1e9 100644
--- a/tests/providers/amazon/aws/sensors/test_sqs.py
+++ b/tests/providers/amazon/aws/sensors/test_sqs.py
@@ -20,6 +20,7 @@
import unittest
from unittest import mock
+import pytest
from moto import mock_sqs
from airflow.exceptions import AirflowException
@@ -49,23 +50,22 @@ def test_poke_success(self):
self.sqs_hook.send_message(queue_url='test', message_body='hello')
result = self.sensor.poke(self.mock_context)
- self.assertTrue(result)
+ assert result
- self.assertTrue(
- "'Body': 'hello'" in str(self.mock_context['ti'].method_calls),
- "context call should contain message hello",
- )
+ assert "'Body': 'hello'" in str(
+ self.mock_context['ti'].method_calls
+ ), "context call should contain message hello"
@mock_sqs
def test_poke_no_message_failed(self):
self.sqs_hook.create_queue('test')
result = self.sensor.poke(self.mock_context)
- self.assertFalse(result)
+ assert not result
context_calls = []
- self.assertTrue(self.mock_context['ti'].method_calls == context_calls, "context call should be same")
+ assert self.mock_context['ti'].method_calls == context_calls, "context call should be same"
@mock.patch.object(SQSHook, 'get_conn')
def test_poke_delete_raise_airflow_exception(self, mock_conn):
@@ -95,15 +95,15 @@ def test_poke_delete_raise_airflow_exception(self, mock_conn):
'Failed': [{'Id': '22f67273-4dbc-4c19-83b5-aee71bfeb832'}]
}
- with self.assertRaises(AirflowException) as context:
+ with pytest.raises(AirflowException) as ctx:
self.sensor.poke(self.mock_context)
- self.assertTrue('Delete SQS Messages failed' in context.exception.args[0])
+ assert 'Delete SQS Messages failed' in ctx.value.args[0]
@mock.patch.object(SQSHook, 'get_conn')
def test_poke_receive_raise_exception(self, mock_conn):
mock_conn.return_value.receive_message.side_effect = Exception('test exception')
- with self.assertRaises(Exception) as context:
+ with pytest.raises(Exception) as ctx:
self.sensor.poke(self.mock_context)
- self.assertTrue('test exception' in context.exception.args[0])
+ assert 'test exception' in ctx.value.args[0]
diff --git a/tests/providers/amazon/aws/sensors/test_step_function_execution.py b/tests/providers/amazon/aws/sensors/test_step_function_execution.py
index bbfffacbb9c5c..03235392f0335 100644
--- a/tests/providers/amazon/aws/sensors/test_step_function_execution.py
+++ b/tests/providers/amazon/aws/sensors/test_step_function_execution.py
@@ -20,6 +20,7 @@
from unittest import mock
from unittest.mock import MagicMock
+import pytest
from parameterized import parameterized
from airflow.exceptions import AirflowException
@@ -43,10 +44,10 @@ def test_init(self):
task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME
)
- self.assertEqual(TASK_ID, sensor.task_id)
- self.assertEqual(EXECUTION_ARN, sensor.execution_arn)
- self.assertEqual(AWS_CONN_ID, sensor.aws_conn_id)
- self.assertEqual(REGION_NAME, sensor.region_name)
+ assert TASK_ID == sensor.task_id
+ assert EXECUTION_ARN == sensor.execution_arn
+ assert AWS_CONN_ID == sensor.aws_conn_id
+ assert REGION_NAME == sensor.region_name
@parameterized.expand([('FAILED',), ('TIMED_OUT',), ('ABORTED',)])
@mock.patch('airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionHook')
@@ -60,7 +61,7 @@ def test_exceptions(self, mock_status, mock_hook):
task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME
)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
sensor.poke(self.mock_context)
@mock.patch('airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionHook')
@@ -74,7 +75,7 @@ def test_running(self, mock_hook):
task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME
)
- self.assertFalse(sensor.poke(self.mock_context))
+ assert not sensor.poke(self.mock_context)
@mock.patch('airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionHook')
def test_succeeded(self, mock_hook):
@@ -87,4 +88,4 @@ def test_succeeded(self, mock_hook):
task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME
)
- self.assertTrue(sensor.poke(self.mock_context))
+ assert sensor.poke(self.mock_context)
diff --git a/tests/providers/amazon/aws/transfers/test_dynamodb_to_s3.py b/tests/providers/amazon/aws/transfers/test_dynamodb_to_s3.py
index 522f6ad1f5cac..63246d6b5c84a 100644
--- a/tests/providers/amazon/aws/transfers/test_dynamodb_to_s3.py
+++ b/tests/providers/amazon/aws/transfers/test_dynamodb_to_s3.py
@@ -62,4 +62,4 @@ def test_dynamodb_to_s3_success(self, mock_aws_dynamodb_hook, mock_s3_hook):
dynamodb_to_s3_operator.execute(context={})
- self.assertEqual([{'a': 1}, {'b': 2}, {'c': 3}], self.output_queue)
+ assert [{'a': 1}, {'b': 2}, {'c': 3}] == self.output_queue
diff --git a/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py b/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py
index fc62cc4635c19..eb13b3b2d704c 100644
--- a/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py
+++ b/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py
@@ -65,8 +65,8 @@ def test_execute_incremental(self, mock_hook, mock_hook2):
# we expect all except first file in MOCK_FILES to be uploaded
# and all the MOCK_FILES to be present at the S3 bucket
uploaded_files = operator.execute(None)
- self.assertEqual(sorted(MOCK_FILES[1:]), sorted(uploaded_files))
- self.assertEqual(sorted(MOCK_FILES), sorted(hook.list_keys('bucket', delimiter='/')))
+ assert sorted(MOCK_FILES[1:]) == sorted(uploaded_files)
+ assert sorted(MOCK_FILES) == sorted(hook.list_keys('bucket', delimiter='/'))
# Test2: All the files are already in origin and destination without replace
@mock_s3
@@ -96,8 +96,8 @@ def test_execute_without_replace(self, mock_hook, mock_hook2):
# we expect nothing to be uploaded
# and all the MOCK_FILES to be present at the S3 bucket
uploaded_files = operator.execute(None)
- self.assertEqual([], uploaded_files)
- self.assertEqual(sorted(MOCK_FILES), sorted(hook.list_keys('bucket', delimiter='/')))
+ assert [] == uploaded_files
+ assert sorted(MOCK_FILES) == sorted(hook.list_keys('bucket', delimiter='/'))
# Test3: There are no files in destination bucket
@mock_s3
@@ -125,8 +125,8 @@ def test_execute(self, mock_hook, mock_hook2):
# we expect all MOCK_FILES to be uploaded
# and all MOCK_FILES to be present at the S3 bucket
uploaded_files = operator.execute(None)
- self.assertEqual(sorted(MOCK_FILES), sorted(uploaded_files))
- self.assertEqual(sorted(MOCK_FILES), sorted(hook.list_keys('bucket', delimiter='/')))
+ assert sorted(MOCK_FILES) == sorted(uploaded_files)
+ assert sorted(MOCK_FILES) == sorted(hook.list_keys('bucket', delimiter='/'))
# Test4: Destination and Origin are in sync but replace all files in destination
@mock_s3
@@ -156,8 +156,8 @@ def test_execute_with_replace(self, mock_hook, mock_hook2):
# we expect all MOCK_FILES to be uploaded and replace the existing ones
# and all MOCK_FILES to be present at the S3 bucket
uploaded_files = operator.execute(None)
- self.assertEqual(sorted(MOCK_FILES), sorted(uploaded_files))
- self.assertEqual(sorted(MOCK_FILES), sorted(hook.list_keys('bucket', delimiter='/')))
+ assert sorted(MOCK_FILES) == sorted(uploaded_files)
+ assert sorted(MOCK_FILES) == sorted(hook.list_keys('bucket', delimiter='/'))
# Test5: Incremental sync with replace
@mock_s3
@@ -187,8 +187,8 @@ def test_execute_incremental_with_replace(self, mock_hook, mock_hook2):
# we expect all the MOCK_FILES to be uploaded and replace the existing ones
# and all MOCK_FILES to be present at the S3 bucket
uploaded_files = operator.execute(None)
- self.assertEqual(sorted(MOCK_FILES), sorted(uploaded_files))
- self.assertEqual(sorted(MOCK_FILES), sorted(hook.list_keys('bucket', delimiter='/')))
+ assert sorted(MOCK_FILES) == sorted(uploaded_files)
+ assert sorted(MOCK_FILES) == sorted(hook.list_keys('bucket', delimiter='/'))
@mock_s3
@mock.patch('airflow.providers.google.cloud.operators.gcs.GCSHook')
@@ -271,4 +271,4 @@ def test_execute_with_s3_acl_policy(self, mock_load_bytes, mock_gcs_hook, mock_g
# Make sure the acl_policy parameter is passed to the upload method
_, kwargs = mock_load_bytes.call_args
- self.assertEqual(kwargs['acl_policy'], S3_ACL_POLICY)
+ assert kwargs['acl_policy'] == S3_ACL_POLICY
diff --git a/tests/providers/amazon/aws/transfers/test_google_api_to_s3.py b/tests/providers/amazon/aws/transfers/test_google_api_to_s3.py
index 44d8e3f6d404a..446b33cd1c635 100644
--- a/tests/providers/amazon/aws/transfers/test_google_api_to_s3.py
+++ b/tests/providers/amazon/aws/transfers/test_google_api_to_s3.py
@@ -19,6 +19,8 @@
import unittest
from unittest.mock import Mock, patch
+import pytest
+
from airflow import models
from airflow.configuration import load_test_config
from airflow.models.xcom import MAX_XCOM_SIZE
@@ -139,7 +141,8 @@ def test_execute_with_xcom_exceeded_max_xcom_size(
}
context['task_instance'].xcom_pull.return_value = {}
- self.assertRaises(RuntimeError, GoogleApiToS3Operator(**self.kwargs, **xcom_kwargs).execute, context)
+ with pytest.raises(RuntimeError):
+ GoogleApiToS3Operator(**self.kwargs, **xcom_kwargs).execute(context)
mock_google_api_hook_query.assert_called_once_with(
endpoint=self.kwargs['google_api_endpoint_path'],
diff --git a/tests/providers/amazon/aws/transfers/test_hive_to_dynamodb.py b/tests/providers/amazon/aws/transfers/test_hive_to_dynamodb.py
index 9a869c1af81a3..06e3f263c9128 100644
--- a/tests/providers/amazon/aws/transfers/test_hive_to_dynamodb.py
+++ b/tests/providers/amazon/aws/transfers/test_hive_to_dynamodb.py
@@ -54,7 +54,7 @@ def process_data(data, *args, **kwargs):
@mock_dynamodb2
def test_get_conn_returns_a_boto3_connection(self):
hook = AwsDynamoDBHook(aws_conn_id='aws_default')
- self.assertIsNotNone(hook.get_conn())
+ assert hook.get_conn() is not None
@mock.patch(
'airflow.providers.apache.hive.hooks.hive.HiveServer2Hook.get_pandas_df',
@@ -85,7 +85,7 @@ def test_get_records_with_schema(self, mock_get_pandas_df):
table = self.hook.get_conn().Table('test_airflow')
table.meta.client.get_waiter('table_exists').wait(TableName='test_airflow')
- self.assertEqual(table.item_count, 1)
+ assert table.item_count == 1
@mock.patch(
'airflow.providers.apache.hive.hooks.hive.HiveServer2Hook.get_pandas_df',
@@ -117,4 +117,4 @@ def test_pre_process_records_with_schema(self, mock_get_pandas_df):
table = self.hook.get_conn().Table('test_airflow')
table.meta.client.get_waiter('table_exists').wait(TableName='test_airflow')
- self.assertEqual(table.item_count, 1)
+ assert table.item_count == 1
diff --git a/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py b/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py
index 57a5a8d5a0332..746b1dd0504c6 100644
--- a/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py
+++ b/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py
@@ -58,17 +58,17 @@ def setUp(self):
)
def test_init(self):
- self.assertEqual(self.mock_operator.task_id, TASK_ID)
- self.assertEqual(self.mock_operator.mongo_conn_id, MONGO_CONN_ID)
- self.assertEqual(self.mock_operator.s3_conn_id, S3_CONN_ID)
- self.assertEqual(self.mock_operator.mongo_collection, MONGO_COLLECTION)
- self.assertEqual(self.mock_operator.mongo_query, MONGO_QUERY)
- self.assertEqual(self.mock_operator.s3_bucket, S3_BUCKET)
- self.assertEqual(self.mock_operator.s3_key, S3_KEY)
- self.assertEqual(self.mock_operator.compression, COMPRESSION)
+ assert self.mock_operator.task_id == TASK_ID
+ assert self.mock_operator.mongo_conn_id == MONGO_CONN_ID
+ assert self.mock_operator.s3_conn_id == S3_CONN_ID
+ assert self.mock_operator.mongo_collection == MONGO_COLLECTION
+ assert self.mock_operator.mongo_query == MONGO_QUERY
+ assert self.mock_operator.s3_bucket == S3_BUCKET
+ assert self.mock_operator.s3_key == S3_KEY
+ assert self.mock_operator.compression == COMPRESSION
def test_template_field_overrides(self):
- self.assertEqual(self.mock_operator.template_fields, ['s3_key', 'mongo_query', 'mongo_collection'])
+ assert self.mock_operator.template_fields == ['s3_key', 'mongo_query', 'mongo_collection']
def test_render_template(self):
ti = TaskInstance(self.mock_operator, DEFAULT_DATE)
@@ -76,7 +76,7 @@ def test_render_template(self):
expected_rendered_template = {'$lt': '2017-01-01T00:00:00+00:00Z'}
- self.assertDictEqual(expected_rendered_template, getattr(self.mock_operator, 'mongo_query'))
+ assert expected_rendered_template == getattr(self.mock_operator, 'mongo_query')
@mock.patch('airflow.providers.amazon.aws.transfers.mongo_to_s3.MongoHook')
@mock.patch('airflow.providers.amazon.aws.transfers.mongo_to_s3.S3Hook')
diff --git a/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py b/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py
index 4545418055228..8db3becfb2259 100644
--- a/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py
+++ b/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py
@@ -89,6 +89,10 @@ def test_execute(
assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], unload_query)
def test_template_fields_overrides(self):
- self.assertEqual(
- RedshiftToS3Operator.template_fields, ('s3_bucket', 's3_key', 'schema', 'table', 'unload_options')
+ assert RedshiftToS3Operator.template_fields == (
+ 's3_bucket',
+ 's3_key',
+ 'schema',
+ 'table',
+ 'unload_options',
)
diff --git a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
index ed562d4d5ce04..c59d15e7b1e24 100644
--- a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
+++ b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
@@ -111,6 +111,10 @@ def test_truncate(self, mock_run, mock_session):
assert mock_run.call_count == 1
def test_template_fields_overrides(self):
- self.assertEqual(
- S3ToRedshiftOperator.template_fields, ('s3_bucket', 's3_key', 'schema', 'table', 'copy_options')
+ assert S3ToRedshiftOperator.template_fields == (
+ 's3_bucket',
+ 's3_key',
+ 'schema',
+ 'table',
+ 'copy_options',
)
diff --git a/tests/providers/amazon/aws/transfers/test_s3_to_sftp.py b/tests/providers/amazon/aws/transfers/test_s3_to_sftp.py
index 4108f88923806..539230c965293 100644
--- a/tests/providers/amazon/aws/transfers/test_s3_to_sftp.py
+++ b/tests/providers/amazon/aws/transfers/test_s3_to_sftp.py
@@ -81,7 +81,7 @@ def test_s3_to_sftp_operation(self):
# Test for creation of s3 bucket
conn = boto3.client('s3')
conn.create_bucket(Bucket=self.s3_bucket)
- self.assertTrue(self.s3_hook.check_for_bucket(self.s3_bucket))
+ assert self.s3_hook.check_for_bucket(self.s3_bucket)
with open(LOCAL_FILE_PATH, 'w') as file:
file.write(test_remote_file_content)
@@ -90,10 +90,10 @@ def test_s3_to_sftp_operation(self):
# Check if object was created in s3
objects_in_dest_bucket = conn.list_objects(Bucket=self.s3_bucket, Prefix=self.s3_key)
# there should be object found, and there should only be one object found
- self.assertEqual(len(objects_in_dest_bucket['Contents']), 1)
+ assert len(objects_in_dest_bucket['Contents']) == 1
# the object found should be consistent with dest_key specified earlier
- self.assertEqual(objects_in_dest_bucket['Contents'][0]['Key'], self.s3_key)
+ assert objects_in_dest_bucket['Contents'][0]['Key'] == self.s3_key
# get remote file to local
run_task = S3ToSFTPOperator(
@@ -105,7 +105,7 @@ def test_s3_to_sftp_operation(self):
task_id=TASK_ID,
dag=self.dag,
)
- self.assertIsNotNone(run_task)
+ assert run_task is not None
run_task.execute(None)
@@ -117,18 +117,17 @@ def test_s3_to_sftp_operation(self):
do_xcom_push=True,
dag=self.dag,
)
- self.assertIsNotNone(check_file_task)
+ assert check_file_task is not None
ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
ti3.run()
- self.assertEqual(
- ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
- test_remote_file_content.encode('utf-8'),
- )
+ assert ti3.xcom_pull(
+ task_ids='test_check_file', key='return_value'
+ ).strip() == test_remote_file_content.encode('utf-8')
# Clean up after finishing with test
conn.delete_object(Bucket=self.s3_bucket, Key=self.s3_key)
conn.delete_bucket(Bucket=self.s3_bucket)
- self.assertFalse(self.s3_hook.check_for_bucket(self.s3_bucket))
+ assert not self.s3_hook.check_for_bucket(self.s3_bucket)
def delete_remote_resource(self):
# check the remote file content
@@ -139,7 +138,7 @@ def delete_remote_resource(self):
do_xcom_push=True,
dag=self.dag,
)
- self.assertIsNotNone(remove_file_task)
+ assert remove_file_task is not None
ti3 = TaskInstance(task=remove_file_task, execution_date=timezone.utcnow())
ti3.run()
diff --git a/tests/providers/amazon/aws/transfers/test_sftp_to_s3.py b/tests/providers/amazon/aws/transfers/test_sftp_to_s3.py
index c9767c3b8b23b..30267d46eecb9 100644
--- a/tests/providers/amazon/aws/transfers/test_sftp_to_s3.py
+++ b/tests/providers/amazon/aws/transfers/test_sftp_to_s3.py
@@ -85,14 +85,14 @@ def test_sftp_to_s3_operation(self):
do_xcom_push=True,
dag=self.dag,
)
- self.assertIsNotNone(create_file_task)
+ assert create_file_task is not None
ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
ti1.run()
# Test for creation of s3 bucket
conn = boto3.client('s3')
conn.create_bucket(Bucket=self.s3_bucket)
- self.assertTrue(self.s3_hook.check_for_bucket(self.s3_bucket))
+ assert self.s3_hook.check_for_bucket(self.s3_bucket)
# get remote file to local
run_task = SFTPToS3Operator(
@@ -104,19 +104,19 @@ def test_sftp_to_s3_operation(self):
task_id='test_sftp_to_s3',
dag=self.dag,
)
- self.assertIsNotNone(run_task)
+ assert run_task is not None
run_task.execute(None)
# Check if object was created in s3
objects_in_dest_bucket = conn.list_objects(Bucket=self.s3_bucket, Prefix=self.s3_key)
# there should be object found, and there should only be one object found
- self.assertEqual(len(objects_in_dest_bucket['Contents']), 1)
+ assert len(objects_in_dest_bucket['Contents']) == 1
# the object found should be consistent with dest_key specified earlier
- self.assertEqual(objects_in_dest_bucket['Contents'][0]['Key'], self.s3_key)
+ assert objects_in_dest_bucket['Contents'][0]['Key'] == self.s3_key
# Clean up after finishing with test
conn.delete_object(Bucket=self.s3_bucket, Key=self.s3_key)
conn.delete_bucket(Bucket=self.s3_bucket)
- self.assertFalse(self.s3_hook.check_for_bucket(self.s3_bucket))
+ assert not self.s3_hook.check_for_bucket(self.s3_bucket)
diff --git a/tests/providers/apache/cassandra/hooks/test_cassandra.py b/tests/providers/apache/cassandra/hooks/test_cassandra.py
index 5e37acef674ec..575d6b1920b39 100644
--- a/tests/providers/apache/cassandra/hooks/test_cassandra.py
+++ b/tests/providers/apache/cassandra/hooks/test_cassandra.py
@@ -82,10 +82,10 @@ def test_get_conn(self):
mock_connect.assert_called_once_with('test_keyspace')
cluster = hook.get_cluster()
- self.assertEqual(cluster.contact_points, ['host-1', 'host-2'])
- self.assertEqual(cluster.port, 9042)
- self.assertEqual(cluster.protocol_version, 4)
- self.assertTrue(isinstance(cluster.load_balancing_policy, TokenAwarePolicy))
+ assert cluster.contact_points == ['host-1', 'host-2']
+ assert cluster.port == 9042
+ assert cluster.protocol_version == 4
+ assert isinstance(cluster.load_balancing_policy, TokenAwarePolicy)
def test_get_lb_policy_with_no_args(self):
# test LB policies with no args
@@ -163,12 +163,12 @@ def _assert_get_lb_policy(
thrown = False
try:
policy = CassandraHook.get_lb_policy(policy_name, policy_args)
- self.assertTrue(isinstance(policy, expected_policy_type))
+ assert isinstance(policy, expected_policy_type)
if expected_child_policy_type:
- self.assertTrue(isinstance(policy._child_policy, expected_child_policy_type))
+ assert isinstance(policy._child_policy, expected_child_policy_type)
except Exception: # pylint: disable=broad-except
thrown = True
- self.assertEqual(should_throw, thrown)
+ assert should_throw == thrown
def test_record_exists_with_keyspace_from_cql(self):
hook = CassandraHook("cassandra_default")
@@ -181,8 +181,8 @@ def test_record_exists_with_keyspace_from_cql(self):
for cql in cqls:
session.execute(cql)
- self.assertTrue(hook.record_exists("s.t", {"pk1": "foo", "pk2": "bar"}))
- self.assertFalse(hook.record_exists("s.t", {"pk1": "foo", "pk2": "baz"}))
+ assert hook.record_exists("s.t", {"pk1": "foo", "pk2": "bar"})
+ assert not hook.record_exists("s.t", {"pk1": "foo", "pk2": "baz"})
session.shutdown()
hook.shutdown_cluster()
@@ -198,8 +198,8 @@ def test_record_exists_with_keyspace_from_session(self):
for cql in cqls:
session.execute(cql)
- self.assertTrue(hook.record_exists("t", {"pk1": "foo", "pk2": "bar"}))
- self.assertFalse(hook.record_exists("t", {"pk1": "foo", "pk2": "baz"}))
+ assert hook.record_exists("t", {"pk1": "foo", "pk2": "bar"})
+ assert not hook.record_exists("t", {"pk1": "foo", "pk2": "baz"})
session.shutdown()
hook.shutdown_cluster()
@@ -214,8 +214,8 @@ def test_table_exists_with_keyspace_from_cql(self):
for cql in cqls:
session.execute(cql)
- self.assertTrue(hook.table_exists("s.t"))
- self.assertFalse(hook.table_exists("s.u"))
+ assert hook.table_exists("s.t")
+ assert not hook.table_exists("s.u")
session.shutdown()
hook.shutdown_cluster()
@@ -230,8 +230,8 @@ def test_table_exists_with_keyspace_from_session(self):
for cql in cqls:
session.execute(cql)
- self.assertTrue(hook.table_exists("t"))
- self.assertFalse(hook.table_exists("u"))
+ assert hook.table_exists("t")
+ assert not hook.table_exists("u")
session.shutdown()
hook.shutdown_cluster()
diff --git a/tests/providers/apache/cassandra/sensors/test_record.py b/tests/providers/apache/cassandra/sensors/test_record.py
index 35f5aefe4dc29..8785839a1b39f 100644
--- a/tests/providers/apache/cassandra/sensors/test_record.py
+++ b/tests/providers/apache/cassandra/sensors/test_record.py
@@ -37,7 +37,7 @@ def test_poke(self, mock_hook):
)
exists = sensor.poke(dict())
- self.assertTrue(exists)
+ assert exists
mock_hook.return_value.record_exists.assert_called_once_with(TEST_CASSANDRA_TABLE, TEST_CASSANDRA_KEY)
mock_hook.assert_called_once_with(TEST_CASSANDRA_CONN_ID)
@@ -52,7 +52,7 @@ def test_poke_should_not_fail_with_empty_keys(self, mock_hook):
)
exists = sensor.poke(dict())
- self.assertTrue(exists)
+ assert exists
mock_hook.return_value.record_exists.assert_called_once_with(TEST_CASSANDRA_TABLE, None)
mock_hook.assert_called_once_with(TEST_CASSANDRA_CONN_ID)
@@ -69,7 +69,7 @@ def test_poke_should_return_false_for_non_existing_table(self, mock_hook):
)
exists = sensor.poke(dict())
- self.assertFalse(exists)
+ assert not exists
mock_hook.return_value.record_exists.assert_called_once_with(TEST_CASSANDRA_TABLE, TEST_CASSANDRA_KEY)
mock_hook.assert_called_once_with(TEST_CASSANDRA_CONN_ID)
diff --git a/tests/providers/apache/cassandra/sensors/test_table.py b/tests/providers/apache/cassandra/sensors/test_table.py
index 4f35bac4b9dcb..8da8c62c8f459 100644
--- a/tests/providers/apache/cassandra/sensors/test_table.py
+++ b/tests/providers/apache/cassandra/sensors/test_table.py
@@ -36,7 +36,7 @@ def test_poke(self, mock_hook):
)
exists = sensor.poke(dict())
- self.assertTrue(exists)
+ assert exists
mock_hook.return_value.table_exists.assert_called_once_with(TEST_CASSANDRA_TABLE)
mock_hook.assert_called_once_with(TEST_CASSANDRA_CONN_ID)
@@ -52,7 +52,7 @@ def test_poke_should_return_false_for_non_existing_table(self, mock_hook):
)
exists = sensor.poke(dict())
- self.assertFalse(exists)
+ assert not exists
mock_hook.return_value.table_exists.assert_called_once_with(TEST_CASSANDRA_TABLE)
mock_hook.assert_called_once_with(TEST_CASSANDRA_CONN_ID)
@@ -66,7 +66,7 @@ def test_poke_should_succeed_for_table_with_mentioned_keyspace(self, mock_hook):
)
exists = sensor.poke(dict())
- self.assertTrue(exists)
+ assert exists
mock_hook.return_value.table_exists.assert_called_once_with(TEST_CASSANDRA_TABLE_WITH_KEYSPACE)
mock_hook.assert_called_once_with(TEST_CASSANDRA_CONN_ID)
diff --git a/tests/providers/apache/druid/hooks/test_druid.py b/tests/providers/apache/druid/hooks/test_druid.py
index 8954f4746d0ff..afd7c1c1d8276 100644
--- a/tests/providers/apache/druid/hooks/test_druid.py
+++ b/tests/providers/apache/druid/hooks/test_druid.py
@@ -20,6 +20,7 @@
import unittest
from unittest.mock import MagicMock, patch
+import pytest
import requests
import requests_mock
@@ -52,11 +53,11 @@ def test_submit_gone_wrong(self, m):
)
# The job failed for some reason
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.db_hook.submit_indexing_job('Long json file')
- self.assertTrue(task_post.called_once)
- self.assertTrue(status_check.called_once)
+ assert task_post.called_once
+ assert status_check.called_once
@requests_mock.mock()
def test_submit_ok(self, m):
@@ -72,8 +73,8 @@ def test_submit_ok(self, m):
# Exists just as it should
self.db_hook.submit_indexing_job('Long json file')
- self.assertTrue(task_post.called_once)
- self.assertTrue(status_check.called_once)
+ assert task_post.called_once
+ assert status_check.called_once
@requests_mock.mock()
def test_submit_correct_json_body(self, m):
@@ -93,11 +94,11 @@ def test_submit_correct_json_body(self, m):
"""
self.db_hook.submit_indexing_job(json_ingestion_string)
- self.assertTrue(task_post.called_once)
- self.assertTrue(status_check.called_once)
+ assert task_post.called_once
+ assert status_check.called_once
if task_post.called_once:
req_body = task_post.request_history[0].json()
- self.assertEqual(req_body['task'], "9f8a7359-77d4-4612-b0cd-cc2f6a3c28de")
+ assert req_body['task'] == "9f8a7359-77d4-4612-b0cd-cc2f6a3c28de"
@requests_mock.mock()
def test_submit_unknown_response(self, m):
@@ -111,11 +112,11 @@ def test_submit_unknown_response(self, m):
)
# An unknown error code
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.db_hook.submit_indexing_job('Long json file')
- self.assertTrue(task_post.called_once)
- self.assertTrue(status_check.called_once)
+ assert task_post.called_once
+ assert status_check.called_once
@requests_mock.mock()
def test_submit_timeout(self, m):
@@ -136,12 +137,12 @@ def test_submit_timeout(self, m):
)
# Because the jobs keeps running
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.db_hook.submit_indexing_job('Long json file')
- self.assertTrue(task_post.called_once)
- self.assertTrue(status_check.called)
- self.assertTrue(shutdown_post.called_once)
+ assert task_post.called_once
+ assert status_check.called
+ assert shutdown_post.called_once
@patch('airflow.providers.apache.druid.hooks.druid.DruidHook.get_connection')
def test_get_conn_url(self, mock_get_connection):
@@ -152,7 +153,7 @@ def test_get_conn_url(self, mock_get_connection):
get_conn_value.extra_dejson = {'endpoint': 'ingest'}
mock_get_connection.return_value = get_conn_value
hook = DruidHook(timeout=1, max_ingestion_time=5)
- self.assertEqual(hook.get_conn_url(), 'https://test_host:1/ingest')
+ assert hook.get_conn_url() == 'https://test_host:1/ingest'
@patch('airflow.providers.apache.druid.hooks.druid.DruidHook.get_connection')
def test_get_auth(self, mock_get_connection):
@@ -161,7 +162,7 @@ def test_get_auth(self, mock_get_connection):
get_conn_value.password = 'password'
mock_get_connection.return_value = get_conn_value
expected = requests.auth.HTTPBasicAuth('airflow', 'password')
- self.assertEqual(self.db_hook.get_auth(), expected)
+ assert self.db_hook.get_auth() == expected
@patch('airflow.providers.apache.druid.hooks.druid.DruidHook.get_connection')
def test_get_auth_with_no_user(self, mock_get_connection):
@@ -169,7 +170,7 @@ def test_get_auth_with_no_user(self, mock_get_connection):
get_conn_value.login = None
get_conn_value.password = 'password'
mock_get_connection.return_value = get_conn_value
- self.assertEqual(self.db_hook.get_auth(), None)
+ assert self.db_hook.get_auth() is None
@patch('airflow.providers.apache.druid.hooks.druid.DruidHook.get_connection')
def test_get_auth_with_no_password(self, mock_get_connection):
@@ -177,7 +178,7 @@ def test_get_auth_with_no_password(self, mock_get_connection):
get_conn_value.login = 'airflow'
get_conn_value.password = None
mock_get_connection.return_value = get_conn_value
- self.assertEqual(self.db_hook.get_auth(), None)
+ assert self.db_hook.get_auth() is None
@patch('airflow.providers.apache.druid.hooks.druid.DruidHook.get_connection')
def test_get_auth_with_no_user_and_password(self, mock_get_connection):
@@ -185,7 +186,7 @@ def test_get_auth_with_no_user_and_password(self, mock_get_connection):
get_conn_value.login = None
get_conn_value.password = None
mock_get_connection.return_value = get_conn_value
- self.assertEqual(self.db_hook.get_auth(), None)
+ assert self.db_hook.get_auth() is None
class TestDruidDbApiHook(unittest.TestCase):
@@ -210,14 +211,14 @@ def get_connection(self, conn_id):
def test_get_uri(self):
db_hook = self.db_hook()
- self.assertEqual('druid://host:1000/druid/v2/sql', db_hook.get_uri())
+ assert 'druid://host:1000/druid/v2/sql' == db_hook.get_uri()
def test_get_first_record(self):
statement = 'SQL'
result_sets = [('row1',), ('row2',)]
self.cur.fetchone.return_value = result_sets[0]
- self.assertEqual(result_sets[0], self.db_hook().get_first(statement))
+ assert result_sets[0] == self.db_hook().get_first(statement)
assert self.conn.close.call_count == 1
assert self.cur.close.call_count == 1
self.cur.execute.assert_called_once_with(statement)
@@ -227,7 +228,7 @@ def test_get_records(self):
result_sets = [('row1',), ('row2',)]
self.cur.fetchall.return_value = result_sets
- self.assertEqual(result_sets, self.db_hook().get_records(statement))
+ assert result_sets == self.db_hook().get_records(statement)
assert self.conn.close.call_count == 1
assert self.cur.close.call_count == 1
self.cur.execute.assert_called_once_with(statement)
@@ -240,9 +241,9 @@ def test_get_pandas_df(self):
self.cur.fetchall.return_value = result_sets
df = self.db_hook().get_pandas_df(statement)
- self.assertEqual(column, df.columns[0])
+ assert column == df.columns[0]
for i in range(len(result_sets)): # pylint: disable=consider-using-enumerate
- self.assertEqual(result_sets[i][0], df.values.tolist()[i][0])
+ assert result_sets[i][0] == df.values.tolist()[i][0]
assert self.conn.close.call_count == 1
assert self.cur.close.call_count == 1
self.cur.execute.assert_called_once_with(statement)
diff --git a/tests/providers/apache/druid/operators/test_druid.py b/tests/providers/apache/druid/operators/test_druid.py
index 9e563e4c07077..cbdfce5c206e7 100644
--- a/tests/providers/apache/druid/operators/test_druid.py
+++ b/tests/providers/apache/druid/operators/test_druid.py
@@ -67,4 +67,4 @@ def test_render_template(self):
}
}
'''
- self.assertEqual(expected, getattr(operator, 'json_index_file'))
+ assert expected == getattr(operator, 'json_index_file')
diff --git a/tests/providers/apache/druid/operators/test_druid_check.py b/tests/providers/apache/druid/operators/test_druid_check.py
index bb84e9d47daec..0e895dd69b927 100644
--- a/tests/providers/apache/druid/operators/test_druid_check.py
+++ b/tests/providers/apache/druid/operators/test_druid_check.py
@@ -21,6 +21,8 @@
from datetime import datetime
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.models import DAG
from airflow.providers.apache.druid.operators.druid_check import DruidCheckOperator
@@ -59,5 +61,5 @@ def test_execute_fail(self, mock_get_first):
operator = self.__construct_operator(sql)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
operator.execute()
diff --git a/tests/providers/apache/druid/transfers/test_hive_to_druid.py b/tests/providers/apache/druid/transfers/test_hive_to_druid.py
index 968dd1d78f90a..944aa157f1a5a 100644
--- a/tests/providers/apache/druid/transfers/test_hive_to_druid.py
+++ b/tests/providers/apache/druid/transfers/test_hive_to_druid.py
@@ -116,4 +116,4 @@ def test_construct_ingest_query(self):
}
# Make sure it is like we expect it
- self.assertEqual(provided_index_spec, expected_index_spec)
+ assert provided_index_spec == expected_index_spec
diff --git a/tests/providers/apache/hdfs/hooks/test_hdfs.py b/tests/providers/apache/hdfs/hooks/test_hdfs.py
index 032b8b82cee83..3dd42094fb8c5 100644
--- a/tests/providers/apache/hdfs/hooks/test_hdfs.py
+++ b/tests/providers/apache/hdfs/hooks/test_hdfs.py
@@ -43,10 +43,10 @@ class TestHDFSHook(unittest.TestCase):
)
def test_get_client(self):
client = HDFSHook(proxy_user='foo').get_conn()
- self.assertIsInstance(client, snakebite.client.Client)
- self.assertEqual('localhost', client.host)
- self.assertEqual(8020, client.port)
- self.assertEqual('foo', client.service.channel.effective_user)
+ assert isinstance(client, snakebite.client.Client)
+ assert 'localhost' == client.host
+ assert 8020 == client.port
+ assert 'foo' == client.service.channel.effective_user
@mock.patch.dict(
'os.environ',
@@ -86,4 +86,4 @@ def test_get_ha_client(self, mock_get_connections):
conn_2 = Connection(conn_id='hdfs_default', conn_type='hdfs', host='localhost2', port=8020)
mock_get_connections.return_value = [conn_1, conn_2]
client = HDFSHook().get_conn()
- self.assertIsInstance(client, snakebite.client.HAClient)
+ assert isinstance(client, snakebite.client.HAClient)
diff --git a/tests/providers/apache/hdfs/hooks/test_webhdfs.py b/tests/providers/apache/hdfs/hooks/test_webhdfs.py
index 012d3548cd505..dee8e11ccd8af 100644
--- a/tests/providers/apache/hdfs/hooks/test_webhdfs.py
+++ b/tests/providers/apache/hdfs/hooks/test_webhdfs.py
@@ -19,6 +19,7 @@
import unittest
from unittest.mock import call, patch
+import pytest
from hdfs import HdfsError
from airflow.models.connection import Connection
@@ -50,7 +51,7 @@ def test_get_conn(self, socket_mock, mock_get_connections, mock_insecure_client)
]
)
mock_insecure_client.return_value.status.assert_called_once_with('/')
- self.assertEqual(conn, mock_insecure_client.return_value)
+ assert conn == mock_insecure_client.return_value
@patch('airflow.providers.apache.hdfs.hooks.webhdfs.KerberosClient', create=True)
@patch(
@@ -67,11 +68,11 @@ def test_get_conn_kerberos_security_mode(
connection = mock_get_connections.return_value[0]
mock_kerberos_client.assert_called_once_with(f'http://{connection.host}:{connection.port}')
- self.assertEqual(conn, mock_kerberos_client.return_value)
+ assert conn == mock_kerberos_client.return_value
@patch('airflow.providers.apache.hdfs.hooks.webhdfs.WebHDFSHook._find_valid_server', return_value=None)
def test_get_conn_no_connection_found(self, mock_get_connection):
- with self.assertRaises(AirflowWebHDFSHookException):
+ with pytest.raises(AirflowWebHDFSHookException):
self.webhdfs_hook.get_conn()
@patch('airflow.providers.apache.hdfs.hooks.webhdfs.WebHDFSHook.get_conn')
@@ -83,7 +84,7 @@ def test_check_for_path(self, mock_get_conn):
mock_get_conn.assert_called_once_with()
mock_status = mock_get_conn.return_value.status
mock_status.assert_called_once_with(hdfs_path, strict=False)
- self.assertEqual(exists_path, bool(mock_status.return_value))
+ assert exists_path == bool(mock_status.return_value)
@patch('airflow.providers.apache.hdfs.hooks.webhdfs.WebHDFSHook.get_conn')
def test_load_file(self, mock_get_conn):
@@ -100,8 +101,8 @@ def test_load_file(self, mock_get_conn):
def test_simple_init(self):
hook = WebHDFSHook()
- self.assertIsNone(hook.proxy_user)
+ assert hook.proxy_user is None
def test_init_proxy_user(self):
hook = WebHDFSHook(proxy_user='someone')
- self.assertEqual('someone', hook.proxy_user)
+ assert 'someone' == hook.proxy_user
diff --git a/tests/providers/apache/hdfs/sensors/test_hdfs.py b/tests/providers/apache/hdfs/sensors/test_hdfs.py
index 4d25cad191f9a..c6630cfead2ad 100644
--- a/tests/providers/apache/hdfs/sensors/test_hdfs.py
+++ b/tests/providers/apache/hdfs/sensors/test_hdfs.py
@@ -20,6 +20,8 @@
import unittest
from datetime import timedelta
+import pytest
+
from airflow.exceptions import AirflowSensorTimeout
from airflow.providers.apache.hdfs.sensors.hdfs import HdfsFolderSensor, HdfsRegexSensor, HdfsSensor
from airflow.utils.timezone import datetime
@@ -70,7 +72,7 @@ def test_legacy_file_exist_but_filesize(self):
# When
# Then
- with self.assertRaises(AirflowSensorTimeout):
+ with pytest.raises(AirflowSensorTimeout):
task.execute(None)
def test_legacy_file_does_not_exists(self):
@@ -89,7 +91,7 @@ def test_legacy_file_does_not_exists(self):
# When
# Then
- with self.assertRaises(AirflowSensorTimeout):
+ with pytest.raises(AirflowSensorTimeout):
task.execute(None)
@@ -145,7 +147,7 @@ def test_should_be_empty_directory_fail(self):
# When
# Then
- with self.assertRaises(AirflowSensorTimeout):
+ with pytest.raises(AirflowSensorTimeout):
task.execute(None)
def test_should_be_a_non_empty_directory(self):
@@ -192,7 +194,7 @@ def test_should_be_non_empty_directory_fail(self):
# When
# Then
- with self.assertRaises(AirflowSensorTimeout):
+ with pytest.raises(AirflowSensorTimeout):
task.execute(None)
@@ -250,7 +252,7 @@ def test_should_not_match_regex(self):
# When
# Then
- with self.assertRaises(AirflowSensorTimeout):
+ with pytest.raises(AirflowSensorTimeout):
task.execute(None)
def test_should_match_regex_and_filesize(self):
@@ -305,7 +307,7 @@ def test_should_match_regex_but_filesize(self):
# When
# Then
- with self.assertRaises(AirflowSensorTimeout):
+ with pytest.raises(AirflowSensorTimeout):
task.execute(None)
def test_should_match_regex_but_copyingext(self):
@@ -332,5 +334,5 @@ def test_should_match_regex_but_copyingext(self):
# When
# Then
- with self.assertRaises(AirflowSensorTimeout):
+ with pytest.raises(AirflowSensorTimeout):
task.execute(None)
diff --git a/tests/providers/apache/hdfs/sensors/test_web_hdfs.py b/tests/providers/apache/hdfs/sensors/test_web_hdfs.py
index b05d7b496c7dd..9af9442c334f3 100644
--- a/tests/providers/apache/hdfs/sensors/test_web_hdfs.py
+++ b/tests/providers/apache/hdfs/sensors/test_web_hdfs.py
@@ -35,7 +35,7 @@ def test_poke(self, mock_hook):
)
exists = sensor.poke(dict())
- self.assertTrue(exists)
+ assert exists
mock_hook.return_value.check_for_path.assert_called_once_with(hdfs_path=TEST_HDFS_PATH)
mock_hook.assert_called_once_with(TEST_HDFS_CONN)
@@ -51,7 +51,7 @@ def test_poke_should_return_false_for_non_existing_table(self, mock_hook):
)
exists = sensor.poke(dict())
- self.assertFalse(exists)
+ assert not exists
mock_hook.return_value.check_for_path.assert_called_once_with(hdfs_path=TEST_HDFS_PATH)
mock_hook.assert_called_once_with(TEST_HDFS_CONN)
diff --git a/tests/providers/apache/hive/hooks/test_hive.py b/tests/providers/apache/hive/hooks/test_hive.py
index abfcc2059236e..ce7d26c49a846 100644
--- a/tests/providers/apache/hive/hooks/test_hive.py
+++ b/tests/providers/apache/hive/hooks/test_hive.py
@@ -25,6 +25,7 @@
from unittest import mock
import pandas as pd
+import pytest
from hmsclient import HMSClient
from airflow.exceptions import AirflowException
@@ -199,17 +200,17 @@ def test_run_cli_with_hive_conf(self, mock_popen):
output = hook.run_cli(hql=hql, hive_conf={'key': 'value'})
process_inputs = " ".join(mock_popen.call_args_list[0][0][0])
- self.assertIn('value', process_inputs)
- self.assertIn('test_dag_id', process_inputs)
- self.assertIn('test_task_id', process_inputs)
- self.assertIn('test_execution_date', process_inputs)
- self.assertIn('test_dag_run_id', process_inputs)
+ assert 'value' in process_inputs
+ assert 'test_dag_id' in process_inputs
+ assert 'test_task_id' in process_inputs
+ assert 'test_execution_date' in process_inputs
+ assert 'test_dag_run_id' in process_inputs
- self.assertIn('value', output)
- self.assertIn('test_dag_id', output)
- self.assertIn('test_task_id', output)
- self.assertIn('test_execution_date', output)
- self.assertIn('test_dag_run_id', output)
+ assert 'value' in output
+ assert 'test_dag_id' in output
+ assert 'test_task_id' in output
+ assert 'test_execution_date' in output
+ assert 'test_dag_run_id' in output
@mock.patch('airflow.providers.apache.hive.hooks.hive.HiveCliHook.run_cli')
def test_load_file_without_create_table(self, mock_run_cli):
@@ -262,16 +263,16 @@ def test_load_df(self, mock_to_csv, mock_load_file):
assert mock_to_csv.call_count == 1
kwargs = mock_to_csv.call_args[1]
- self.assertEqual(kwargs["header"], False)
- self.assertEqual(kwargs["index"], False)
- self.assertEqual(kwargs["sep"], delimiter)
+ assert kwargs["header"] is False
+ assert kwargs["index"] is False
+ assert kwargs["sep"] == delimiter
assert mock_load_file.call_count == 1
kwargs = mock_load_file.call_args[1]
- self.assertEqual(kwargs["delimiter"], delimiter)
- self.assertEqual(kwargs["field_dict"], {"c": "STRING"})
- self.assertTrue(isinstance(kwargs["field_dict"], OrderedDict))
- self.assertEqual(kwargs["table"], table)
+ assert kwargs["delimiter"] == delimiter
+ assert kwargs["field_dict"] == {"c": "STRING"}
+ assert isinstance(kwargs["field_dict"], OrderedDict)
+ assert kwargs["table"] == table
@mock.patch('airflow.providers.apache.hive.hooks.hive.HiveCliHook.load_file')
@mock.patch('pandas.DataFrame.to_csv')
@@ -284,8 +285,8 @@ def test_load_df_with_optional_parameters(self, mock_to_csv, mock_load_file):
assert mock_load_file.call_count == 1
kwargs = mock_load_file.call_args[1]
- self.assertEqual(kwargs["create"], create)
- self.assertEqual(kwargs["recreate"], recreate)
+ assert kwargs["create"] == create
+ assert kwargs["recreate"] == recreate
@mock.patch('airflow.providers.apache.hive.hooks.hive.HiveCliHook.run_cli')
def test_load_df_with_data_types(self, mock_run_cli):
@@ -332,11 +333,11 @@ def test_get_max_partition_from_empty_part_specs(self):
max_partition = HiveMetastoreHook._get_max_partition_from_part_specs(
[], 'key1', self.VALID_FILTER_MAP
)
- self.assertIsNone(max_partition)
+ assert max_partition is None
# @mock.patch('airflow.providers.apache.hive.hooks.hive.HiveMetastoreHook', 'get_metastore_client')
def test_get_max_partition_from_valid_part_specs_and_invalid_filter_map(self):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
HiveMetastoreHook._get_max_partition_from_part_specs(
[{'key1': 'value1', 'key2': 'value2'}, {'key1': 'value3', 'key2': 'value4'}],
'key1',
@@ -344,7 +345,7 @@ def test_get_max_partition_from_valid_part_specs_and_invalid_filter_map(self):
)
def test_get_max_partition_from_valid_part_specs_and_invalid_partition_key(self):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
HiveMetastoreHook._get_max_partition_from_part_specs(
[{'key1': 'value1', 'key2': 'value2'}, {'key1': 'value3', 'key2': 'value4'}],
'key3',
@@ -352,7 +353,7 @@ def test_get_max_partition_from_valid_part_specs_and_invalid_partition_key(self)
)
def test_get_max_partition_from_valid_part_specs_and_none_partition_key(self):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
HiveMetastoreHook._get_max_partition_from_part_specs(
[{'key1': 'value1', 'key2': 'value2'}, {'key1': 'value3', 'key2': 'value4'}],
None,
@@ -365,7 +366,7 @@ def test_get_max_partition_from_valid_part_specs_and_none_filter_map(self):
)
# No partition will be filtered out.
- self.assertEqual(max_partition, 'value3')
+ assert max_partition == 'value3'
def test_get_max_partition_from_valid_part_specs(self):
max_partition = HiveMetastoreHook._get_max_partition_from_part_specs(
@@ -373,7 +374,7 @@ def test_get_max_partition_from_valid_part_specs(self):
'key1',
self.VALID_FILTER_MAP,
)
- self.assertEqual(max_partition, 'value1')
+ assert max_partition == 'value1'
def test_get_max_partition_from_valid_part_specs_return_type(self):
max_partition = HiveMetastoreHook._get_max_partition_from_part_specs(
@@ -381,7 +382,7 @@ def test_get_max_partition_from_valid_part_specs_return_type(self):
'key1',
self.VALID_FILTER_MAP,
)
- self.assertIsInstance(max_partition, str)
+ assert isinstance(max_partition, str)
@mock.patch(
"airflow.providers.apache.hive.hooks.hive.HiveMetastoreHook.get_connection",
@@ -399,7 +400,7 @@ def test_get_conn(self):
find_valid_server.return_value = mock.MagicMock(return_value={})
metastore_hook = HiveMetastoreHook()
- self.assertIsInstance(metastore_hook.get_conn(), HMSClient)
+ assert isinstance(metastore_hook.get_conn(), HMSClient)
def test_check_for_partition(self):
# Check for existent partition.
@@ -412,7 +413,7 @@ def test_check_for_partition(self):
metastore.get_partitions_by_filter = mock.MagicMock(return_value=[fake_partition])
- self.assertTrue(self.hook.check_for_partition(self.database, self.table, partition))
+ assert self.hook.check_for_partition(self.database, self.table, partition)
metastore.get_partitions_by_filter(self.database, self.table, partition, 1)
@@ -420,7 +421,7 @@ def test_check_for_partition(self):
missing_partition = f"{self.partition_by}='{self.next_day}'"
metastore.get_partitions_by_filter = mock.MagicMock(return_value=[])
- self.assertFalse(self.hook.check_for_partition(self.database, self.table, missing_partition))
+ assert not self.hook.check_for_partition(self.database, self.table, missing_partition)
metastore.get_partitions_by_filter.assert_called_with(self.database, self.table, missing_partition, 1)
@@ -432,7 +433,7 @@ def test_check_for_named_partition(self):
self.hook.metastore.__enter__().check_for_named_partition = mock.MagicMock(return_value=True)
- self.assertTrue(self.hook.check_for_named_partition(self.database, self.table, partition))
+ assert self.hook.check_for_named_partition(self.database, self.table, partition)
self.hook.metastore.__enter__().check_for_named_partition.assert_called_with(
self.database, self.table, partition
@@ -443,7 +444,7 @@ def test_check_for_named_partition(self):
self.hook.metastore.__enter__().check_for_named_partition = mock.MagicMock(return_value=False)
- self.assertFalse(self.hook.check_for_named_partition(self.database, self.table, missing_partition))
+ assert not self.hook.check_for_named_partition(self.database, self.table, missing_partition)
self.hook.metastore.__enter__().check_for_named_partition.assert_called_with(
self.database, self.table, missing_partition
)
@@ -490,8 +491,8 @@ def test_get_partitions(self):
metastore.get_partitions = mock.MagicMock(return_value=[fake_partition])
partitions = self.hook.get_partitions(schema=self.database, table_name=self.table)
- self.assertEqual(len(partitions), 1)
- self.assertEqual(partitions, [{self.partition_by: DEFAULT_DATE_DS}])
+ assert len(partitions) == 1
+ assert partitions == [{self.partition_by: DEFAULT_DATE_DS}]
metastore.get_table.assert_called_with(dbname=self.database, tbl_name=self.table)
metastore.get_partitions.assert_called_with(
@@ -514,7 +515,7 @@ def test_max_partition(self):
partition = self.hook.max_partition(
schema=self.database, table_name=self.table, field=self.partition_by, filter_map=filter_map
)
- self.assertEqual(partition, DEFAULT_DATE_DS)
+ assert partition == DEFAULT_DATE_DS
metastore.get_table.assert_called_with(dbname=self.database, tbl_name=self.table)
metastore.get_partition_names.assert_called_with(
@@ -526,7 +527,7 @@ def test_table_exists(self):
# Test with existent table.
self.hook.metastore.__enter__().get_table = mock.MagicMock(return_value=True)
- self.assertTrue(self.hook.table_exists(self.table, db=self.database))
+ assert self.hook.table_exists(self.table, db=self.database)
self.hook.metastore.__enter__().get_table.assert_called_with(
dbname='airflow', tbl_name='static_babynames_partitioned'
)
@@ -534,7 +535,7 @@ def test_table_exists(self):
# Test with non-existent table.
self.hook.metastore.__enter__().get_table = mock.MagicMock(side_effect=Exception())
- self.assertFalse(self.hook.table_exists("does-not-exist"))
+ assert not self.hook.table_exists("does-not-exist")
self.hook.metastore.__enter__().get_table.assert_called_with(
dbname='default', tbl_name='does-not-exist'
)
@@ -624,7 +625,7 @@ def test_get_records(self):
):
results = hook.get_records(query, schema=self.database)
- self.assertListEqual(results, [(1, 1), (2, 2)])
+ assert results == [(1, 1), (2, 2)]
hook.get_conn.assert_called_with(self.database)
hook.mock_cursor.execute.assert_any_call('set airflow.ctx.dag_id=test_dag_id')
@@ -651,8 +652,8 @@ def test_get_pandas_df(self):
):
df = hook.get_pandas_df(query, schema=self.database)
- self.assertEqual(len(df), 2)
- self.assertListEqual(df["hive_server_hook.a"].values.tolist(), [1, 2])
+ assert len(df) == 2
+ assert df["hive_server_hook.a"].values.tolist() == [1, 2]
hook.get_conn.assert_called_with(self.database)
hook.mock_cursor.execute.assert_any_call('set airflow.ctx.dag_id=test_dag_id')
@@ -668,7 +669,7 @@ def test_get_results_header(self):
query = f"SELECT * FROM {self.table}"
results = hook.get_results(query, schema=self.database)
- self.assertListEqual([col[0] for col in results['header']], self.columns)
+ assert [col[0] for col in results['header']] == self.columns
def test_get_results_data(self):
hook = MockHiveServer2Hook()
@@ -676,7 +677,7 @@ def test_get_results_data(self):
query = f"SELECT * FROM {self.table}"
results = hook.get_results(query, schema=self.database)
- self.assertListEqual(results['data'], [(1, 1), (2, 2)])
+ assert results['data'] == [(1, 1), (2, 2)]
def test_to_csv(self):
hook = MockHiveServer2Hook()
@@ -704,9 +705,9 @@ def test_to_csv(self):
fetch_size=2,
)
df = pd.read_csv(csv_filepath, sep=',')
- self.assertListEqual(df.columns.tolist(), self.columns)
- self.assertListEqual(df[self.columns[0]].values.tolist(), [1, 2])
- self.assertEqual(len(df), 2)
+ assert df.columns.tolist() == self.columns
+ assert df[self.columns[0]].values.tolist() == [1, 2]
+ assert len(df) == 2
def test_multi_statements(self):
sqls = [
@@ -730,7 +731,7 @@ def test_multi_statements(self):
):
# df = hook.get_pandas_df(query, schema=self.database)
results = hook.get_records(sqls, schema=self.database)
- self.assertListEqual(results, [(1, 1), (2, 2)])
+ assert results == [(1, 1), (2, 2)]
# self.assertEqual(len(df), 2)
# self.assertListEqual(df["hive_server_hook.a"].values.tolist(), [1, 2])
@@ -790,11 +791,11 @@ def test_get_results_with_hive_conf(self):
output = '\n'.join(
res_tuple[0] for res_tuple in hook.get_results(hql=hql, hive_conf={'key': 'value'})['data']
)
- self.assertIn('value', output)
- self.assertIn('test_dag_id', output)
- self.assertIn('test_task_id', output)
- self.assertIn('test_execution_date', output)
- self.assertIn('test_dag_run_id', output)
+ assert 'value' in output
+ assert 'test_dag_id' in output
+ assert 'test_task_id' in output
+ assert 'test_execution_date' in output
+ assert 'test_dag_run_id' in output
class TestHiveCli(unittest.TestCase):
@@ -816,4 +817,4 @@ def test_get_proxy_user_value(self):
result = hook._prepare_cli_cmd()
# Verify
- self.assertIn('hive.server2.proxy.user=a_user_proxy', result[2])
+ assert 'hive.server2.proxy.user=a_user_proxy' in result[2]
diff --git a/tests/providers/apache/hive/operators/test_hive.py b/tests/providers/apache/hive/operators/test_hive.py
index 6541c8561b18f..548c96c773871 100644
--- a/tests/providers/apache/hive/operators/test_hive.py
+++ b/tests/providers/apache/hive/operators/test_hive.py
@@ -41,7 +41,7 @@ def test_hive_airflow_default_config_queue(self):
# just check that the correct default value in test_default.cfg is used
test_config_hive_mapred_queue = conf.get('hive', 'default_hive_mapred_queue')
- self.assertEqual(op.get_hook().mapred_queue, test_config_hive_mapred_queue)
+ assert op.get_hook().mapred_queue == test_config_hive_mapred_queue
def test_hive_airflow_default_config_queue_override(self):
specific_mapred_queue = 'default'
@@ -54,7 +54,7 @@ def test_hive_airflow_default_config_queue_override(self):
dag=self.dag,
)
- self.assertEqual(op.get_hook().mapred_queue, specific_mapred_queue)
+ assert op.get_hook().mapred_queue == specific_mapred_queue
class HiveOperatorTest(TestHiveEnvironment):
@@ -64,7 +64,7 @@ def test_hiveconf_jinja_translate(self):
hiveconf_jinja_translate=True, task_id='dry_run_basic_hql', hql=hql, dag=self.dag
)
op.prepare_template()
- self.assertEqual(op.hql, "SELECT {{ num_col }} FROM {{ table }};")
+ assert op.hql == "SELECT {{ num_col }} FROM {{ table }};"
def test_hiveconf(self):
hql = "SELECT * FROM ${hiveconf:table} PARTITION (${hiveconf:day});"
@@ -75,7 +75,7 @@ def test_hiveconf(self):
dag=self.dag,
)
op.prepare_template()
- self.assertEqual(op.hql, "SELECT * FROM ${hiveconf:table} PARTITION (${hiveconf:day});")
+ assert op.hql == "SELECT * FROM ${hiveconf:table} PARTITION (${hiveconf:day});"
@mock.patch('airflow.providers.apache.hive.operators.hive.HiveOperator.get_hook')
def test_mapred_job_name(self, mock_get_hook):
@@ -89,11 +89,11 @@ def test_mapred_job_name(self, mock_get_hook):
fake_context = {'ti': fake_ti}
op.execute(fake_context)
- self.assertEqual(
+ assert (
"Airflow HiveOperator task for {}.{}.{}.{}".format(
fake_ti.hostname, self.dag.dag_id, op.task_id, fake_execution_date.isoformat()
- ),
- mock_hook.mapred_job_name,
+ )
+ == mock_hook.mapred_job_name
)
diff --git a/tests/providers/apache/hive/operators/test_hive_stats.py b/tests/providers/apache/hive/operators/test_hive_stats.py
index cb9814e368548..d38e007822118 100644
--- a/tests/providers/apache/hive/operators/test_hive_stats.py
+++ b/tests/providers/apache/hive/operators/test_hive_stats.py
@@ -22,6 +22,8 @@
from collections import OrderedDict
from unittest.mock import patch
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.apache.hive.operators.hive_stats import HiveStatsCollectionOperator
from tests.providers.apache.hive import DEFAULT_DATE, DEFAULT_DATE_DS, TestHiveEnvironment
@@ -54,7 +56,7 @@ def test_get_default_exprs(self):
default_exprs = HiveStatsCollectionOperator(**self.kwargs).get_default_exprs(col, None)
- self.assertEqual(default_exprs, {(col, 'non_null'): f'COUNT({col})'})
+ assert default_exprs == {(col, 'non_null'): f'COUNT({col})'}
def test_get_default_exprs_excluded_cols(self):
col = 'excluded_col'
@@ -62,23 +64,20 @@ def test_get_default_exprs_excluded_cols(self):
default_exprs = HiveStatsCollectionOperator(**self.kwargs).get_default_exprs(col, None)
- self.assertEqual(default_exprs, {})
+ assert default_exprs == {}
def test_get_default_exprs_number(self):
col = 'col'
for col_type in ['double', 'int', 'bigint', 'float']:
default_exprs = HiveStatsCollectionOperator(**self.kwargs).get_default_exprs(col, col_type)
- self.assertEqual(
- default_exprs,
- {
- (col, 'avg'): f'AVG({col})',
- (col, 'max'): f'MAX({col})',
- (col, 'min'): f'MIN({col})',
- (col, 'non_null'): f'COUNT({col})',
- (col, 'sum'): f'SUM({col})',
- },
- )
+ assert default_exprs == {
+ (col, 'avg'): f'AVG({col})',
+ (col, 'max'): f'MAX({col})',
+ (col, 'min'): f'MIN({col})',
+ (col, 'non_null'): f'COUNT({col})',
+ (col, 'sum'): f'SUM({col})',
+ }
def test_get_default_exprs_boolean(self):
col = 'col'
@@ -86,14 +85,11 @@ def test_get_default_exprs_boolean(self):
default_exprs = HiveStatsCollectionOperator(**self.kwargs).get_default_exprs(col, col_type)
- self.assertEqual(
- default_exprs,
- {
- (col, 'false'): f'SUM(CASE WHEN NOT {col} THEN 1 ELSE 0 END)',
- (col, 'non_null'): f'COUNT({col})',
- (col, 'true'): f'SUM(CASE WHEN {col} THEN 1 ELSE 0 END)',
- },
- )
+ assert default_exprs == {
+ (col, 'false'): f'SUM(CASE WHEN NOT {col} THEN 1 ELSE 0 END)',
+ (col, 'non_null'): f'COUNT({col})',
+ (col, 'true'): f'SUM(CASE WHEN {col} THEN 1 ELSE 0 END)',
+ }
def test_get_default_exprs_string(self):
col = 'col'
@@ -101,14 +97,11 @@ def test_get_default_exprs_string(self):
default_exprs = HiveStatsCollectionOperator(**self.kwargs).get_default_exprs(col, col_type)
- self.assertEqual(
- default_exprs,
- {
- (col, 'approx_distinct'): f'APPROX_DISTINCT({col})',
- (col, 'len'): f'SUM(CAST(LENGTH({col}) AS BIGINT))',
- (col, 'non_null'): f'COUNT({col})',
- },
- )
+ assert default_exprs == {
+ (col, 'approx_distinct'): f'APPROX_DISTINCT({col})',
+ (col, 'len'): f'SUM(CAST(LENGTH({col}) AS BIGINT))',
+ (col, 'non_null'): f'COUNT({col})',
+ }
@patch('airflow.providers.apache.hive.operators.hive_stats.json.dumps')
@patch('airflow.providers.apache.hive.operators.hive_stats.MySqlHook')
@@ -265,7 +258,8 @@ def test_execute_no_query_results(self, mock_hive_metastore_hook, mock_presto_ho
mock_mysql_hook.return_value.get_records.return_value = False
mock_presto_hook.return_value.get_first.return_value = None
- self.assertRaises(AirflowException, HiveStatsCollectionOperator(**self.kwargs).execute, context={})
+ with pytest.raises(AirflowException):
+ HiveStatsCollectionOperator(**self.kwargs).execute(context={})
@patch('airflow.providers.apache.hive.operators.hive_stats.json.dumps')
@patch('airflow.providers.apache.hive.operators.hive_stats.MySqlHook')
@@ -334,7 +328,7 @@ def test_runs_for_hive_stats(self, mock_hive_metastore_hook):
raw_stats_select_query = mock_mysql_hook.get_records.call_args_list[0][0][0]
actual_stats_select_query = re.sub(r'\s{2,}', ' ', raw_stats_select_query).strip()
- self.assertEqual(expected_stats_select_query, actual_stats_select_query)
+ assert expected_stats_select_query == actual_stats_select_query
insert_rows_val = [
(
diff --git a/tests/providers/apache/hive/sensors/test_named_hive_partition.py b/tests/providers/apache/hive/sensors/test_named_hive_partition.py
index d5887181485f2..a9eea721f95b8 100644
--- a/tests/providers/apache/hive/sensors/test_named_hive_partition.py
+++ b/tests/providers/apache/hive/sensors/test_named_hive_partition.py
@@ -21,6 +21,8 @@
from datetime import timedelta
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowSensorTimeout
from airflow.models.dag import DAG
from airflow.providers.apache.hive.sensors.named_hive_partition import NamedHivePartitionSensor
@@ -72,13 +74,13 @@ def test_parse_partition_name_correct(self):
partition = 'ds=2016-01-01/state=IT'
name = f'{schema}.{table}/{partition}'
parsed_schema, parsed_table, parsed_partition = NamedHivePartitionSensor.parse_partition_name(name)
- self.assertEqual(schema, parsed_schema)
- self.assertEqual(table, parsed_table)
- self.assertEqual(partition, parsed_partition)
+ assert schema == parsed_schema
+ assert table == parsed_table
+ assert partition == parsed_partition
def test_parse_partition_name_incorrect(self):
name = 'incorrect.name'
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
NamedHivePartitionSensor.parse_partition_name(name)
def test_parse_partition_name_default(self):
@@ -86,9 +88,9 @@ def test_parse_partition_name_default(self):
partition = 'ds=2016-01-01/state=IT'
name = f'{table}/{partition}'
parsed_schema, parsed_table, parsed_partition = NamedHivePartitionSensor.parse_partition_name(name)
- self.assertEqual('default', parsed_schema)
- self.assertEqual(table, parsed_table)
- self.assertEqual(partition, parsed_partition)
+ assert 'default' == parsed_schema
+ assert table == parsed_table
+ assert partition == parsed_partition
def test_poke_existing(self):
self.hook.metastore.__enter__().check_for_named_partition.return_value = True
@@ -100,7 +102,7 @@ def test_poke_existing(self):
hook=self.hook,
dag=self.dag,
)
- self.assertTrue(sensor.poke(None))
+ assert sensor.poke(None)
self.hook.metastore.__enter__().check_for_named_partition.assert_called_with(
self.database, self.table, f"{self.partition_by}={DEFAULT_DATE_DS}"
)
@@ -115,7 +117,7 @@ def test_poke_non_existing(self):
hook=self.hook,
dag=self.dag,
)
- self.assertFalse(sensor.poke(None))
+ assert not sensor.poke(None)
self.hook.metastore.__enter__().check_for_named_partition.assert_called_with(
self.database, self.table, f"{self.partition_by}={self.next_day}"
)
@@ -165,12 +167,12 @@ def test_parses_partitions_with_periods(self):
name = NamedHivePartitionSensor.parse_partition_name(
partition="schema.table/part1=this.can.be.an.issue/part2=ok"
)
- self.assertEqual(name[0], "schema")
- self.assertEqual(name[1], "table")
- self.assertEqual(name[2], "part1=this.can.be.an.issue/part2=ok")
+ assert name[0] == "schema"
+ assert name[1] == "table"
+ assert name[2] == "part1=this.can.be.an.issue/part2=ok"
def test_times_out_on_nonexistent_partition(self):
- with self.assertRaises(AirflowSensorTimeout):
+ with pytest.raises(AirflowSensorTimeout):
mock_hive_metastore_hook = MockHiveMetastoreHook()
mock_hive_metastore_hook.check_for_named_partition = mock.MagicMock(return_value=False)
diff --git a/tests/providers/apache/hive/transfers/test_hive_to_mysql.py b/tests/providers/apache/hive/transfers/test_hive_to_mysql.py
index 4253325eec475..e8d75ba05c37f 100644
--- a/tests/providers/apache/hive/transfers/test_hive_to_mysql.py
+++ b/tests/providers/apache/hive/transfers/test_hive_to_mysql.py
@@ -156,7 +156,7 @@ def test_hive_to_mysql(self):
raw_select_name_query = mock_hive_hook.get_records.call_args_list[0][0][0]
actual_select_name_query = re.sub(r'\s{2,}', ' ', raw_select_name_query).strip()
expected_select_name_query = 'SELECT name FROM airflow.static_babynames LIMIT 100'
- self.assertEqual(expected_select_name_query, actual_select_name_query)
+ assert expected_select_name_query == actual_select_name_query
actual_hive_conf = mock_hive_hook.get_records.call_args_list[0][1]['hive_conf']
expected_hive_conf = {
@@ -165,7 +165,7 @@ def test_hive_to_mysql(self):
'airflow.ctx.task_id': 'hive_to_mysql_check',
'airflow.ctx.execution_date': '2015-01-01T00:00:00+00:00',
}
- self.assertEqual(expected_hive_conf, actual_hive_conf)
+ assert expected_hive_conf == actual_hive_conf
expected_mysql_preoperator = [
'DROP TABLE IF EXISTS test_static_babynames;',
diff --git a/tests/providers/apache/hive/transfers/test_mssql_to_hive.py b/tests/providers/apache/hive/transfers/test_mssql_to_hive.py
index dbbf782d3b5be..671cc8b855bae 100644
--- a/tests/providers/apache/hive/transfers/test_mssql_to_hive.py
+++ b/tests/providers/apache/hive/transfers/test_mssql_to_hive.py
@@ -44,24 +44,24 @@ def test_type_map_binary(self):
# pylint: disable=c-extension-no-member
mapped_type = MsSqlToHiveOperator(**self.kwargs).type_map(pymssql.BINARY.value)
- self.assertEqual(mapped_type, 'INT')
+ assert mapped_type == 'INT'
def test_type_map_decimal(self):
# pylint: disable=c-extension-no-member
mapped_type = MsSqlToHiveOperator(**self.kwargs).type_map(pymssql.DECIMAL.value)
- self.assertEqual(mapped_type, 'FLOAT')
+ assert mapped_type == 'FLOAT'
def test_type_map_number(self):
# pylint: disable=c-extension-no-member
mapped_type = MsSqlToHiveOperator(**self.kwargs).type_map(pymssql.NUMBER.value)
- self.assertEqual(mapped_type, 'INT')
+ assert mapped_type == 'INT'
def test_type_map_string(self):
mapped_type = MsSqlToHiveOperator(**self.kwargs).type_map(None)
- self.assertEqual(mapped_type, 'STRING')
+ assert mapped_type == 'STRING'
@patch('airflow.providers.apache.hive.transfers.mssql_to_hive.csv')
@patch('airflow.providers.apache.hive.transfers.mssql_to_hive.NamedTemporaryFile')
diff --git a/tests/providers/apache/hive/transfers/test_mysql_to_hive.py b/tests/providers/apache/hive/transfers/test_mysql_to_hive.py
index 19a368049726a..b921c7464a87c 100644
--- a/tests/providers/apache/hive/transfers/test_mysql_to_hive.py
+++ b/tests/providers/apache/hive/transfers/test_mysql_to_hive.py
@@ -335,7 +335,7 @@ def test_mysql_to_hive_type_conversion(self, mock_load_file):
ordered_dict["c3"] = "BIGINT"
ordered_dict["c4"] = "DECIMAL(38,0)"
ordered_dict["c5"] = "TIMESTAMP"
- self.assertEqual(mock_load_file.call_args[1]["field_dict"], ordered_dict)
+ assert mock_load_file.call_args[1]["field_dict"] == ordered_dict
finally:
with hook.get_conn() as conn:
conn.execute(f"DROP TABLE IF EXISTS {mysql_table}")
@@ -399,7 +399,7 @@ def test_mysql_to_hive_verify_csv_special_char(self, mock_popen, mock_temp_dir):
hive_hook = MockHiveServer2Hook(connection_cursor=mock_cursor)
result = hive_hook.get_records(f"SELECT * FROM {hive_table}")
- self.assertEqual(result[0], db_record)
+ assert result[0] == db_record
hive_cmd = [
'beeline',
@@ -512,7 +512,7 @@ def test_mysql_to_hive_verify_loaded_values(self, mock_popen, mock_temp_dir):
hive_hook = MockHiveServer2Hook(connection_cursor=mock_cursor)
result = hive_hook.get_records(f"SELECT * FROM {hive_table}")
- self.assertEqual(result[0], minmax)
+ assert result[0] == minmax
hive_cmd = [
'beeline',
diff --git a/tests/providers/apache/hive/transfers/test_s3_to_hive.py b/tests/providers/apache/hive/transfers/test_s3_to_hive.py
index 473b889335b65..f22d1863e3d99 100644
--- a/tests/providers/apache/hive/transfers/test_s3_to_hive.py
+++ b/tests/providers/apache/hive/transfers/test_s3_to_hive.py
@@ -28,6 +28,8 @@
from tempfile import NamedTemporaryFile, mkdtemp
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.apache.hive.transfers.s3_to_hive import S3ToHiveOperator
@@ -153,38 +155,32 @@ def _check_file_equality(fn_1, fn_2, ext):
def test_bad_parameters(self):
self.kwargs['check_headers'] = True
self.kwargs['headers'] = False
- self.assertRaisesRegex(AirflowException, "To check_headers.*", S3ToHiveOperator, **self.kwargs)
+ with pytest.raises(AirflowException, match="To check_headers.*"):
+ S3ToHiveOperator(**self.kwargs)
def test__get_top_row_as_list(self):
self.kwargs['delimiter'] = '\t'
fn_txt = self._get_fn('.txt', True)
header_list = S3ToHiveOperator(**self.kwargs)._get_top_row_as_list(fn_txt)
- self.assertEqual(
- header_list, ['Sno', 'Some,Text'], msg="Top row from file doesnt matched expected value"
- )
+ assert header_list == ['Sno', 'Some,Text'], "Top row from file doesnt matched expected value"
self.kwargs['delimiter'] = ','
header_list = S3ToHiveOperator(**self.kwargs)._get_top_row_as_list(fn_txt)
- self.assertEqual(
- header_list, ['Sno\tSome', 'Text'], msg="Top row from file doesnt matched expected value"
- )
+ assert header_list == ['Sno\tSome', 'Text'], "Top row from file doesnt matched expected value"
def test__match_headers(self):
self.kwargs['field_dict'] = OrderedDict([('Sno', 'BIGINT'), ('Some,Text', 'STRING')])
- self.assertTrue(
- S3ToHiveOperator(**self.kwargs)._match_headers(['Sno', 'Some,Text']),
- msg="Header row doesnt match expected value",
- )
+ assert S3ToHiveOperator(**self.kwargs)._match_headers(
+ ['Sno', 'Some,Text']
+ ), "Header row doesnt match expected value"
# Testing with different column order
- self.assertFalse(
- S3ToHiveOperator(**self.kwargs)._match_headers(['Some,Text', 'Sno']),
- msg="Header row doesnt match expected value",
- )
+ assert not S3ToHiveOperator(**self.kwargs)._match_headers(
+ ['Some,Text', 'Sno']
+ ), "Header row doesnt match expected value"
# Testing with extra column in header
- self.assertFalse(
- S3ToHiveOperator(**self.kwargs)._match_headers(['Sno', 'Some,Text', 'ExtraColumn']),
- msg="Header row doesnt match expected value",
- )
+ assert not S3ToHiveOperator(**self.kwargs)._match_headers(
+ ['Sno', 'Some,Text', 'ExtraColumn']
+ ), "Header row doesnt match expected value"
def test__delete_top_row_and_compress(self):
s32hive = S3ToHiveOperator(**self.kwargs)
@@ -192,15 +188,11 @@ def test__delete_top_row_and_compress(self):
fn_txt = self._get_fn('.txt', True)
gz_txt_nh = s32hive._delete_top_row_and_compress(fn_txt, '.gz', self.tmp_dir)
fn_gz = self._get_fn('.gz', False)
- self.assertTrue(
- self._check_file_equality(gz_txt_nh, fn_gz, '.gz'), msg="gz Compressed file not as expected"
- )
+ assert self._check_file_equality(gz_txt_nh, fn_gz, '.gz'), "gz Compressed file not as expected"
# Testing bz2 file type
bz2_txt_nh = s32hive._delete_top_row_and_compress(fn_txt, '.bz2', self.tmp_dir)
fn_bz2 = self._get_fn('.bz2', False)
- self.assertTrue(
- self._check_file_equality(bz2_txt_nh, fn_bz2, '.bz2'), msg="bz2 Compressed file not as expected"
- )
+ assert self._check_file_equality(bz2_txt_nh, fn_bz2, '.bz2'), "bz2 Compressed file not as expected"
@unittest.skipIf(mock is None, 'mock package not present')
@unittest.skipIf(mock_s3 is None, 'moto package not present')
@@ -227,7 +219,7 @@ def test_execute(self, mock_hiveclihook):
# against expected file output
mock_hiveclihook().load_file.side_effect = lambda *args, **kwargs: self.assertTrue(
self._check_file_equality(args[0], op_fn, ext),
- msg=f'{ext} output file not as expected',
+ f'{ext} output file not as expected',
)
# Execute S3ToHiveTransfer
s32hive = S3ToHiveOperator(**self.kwargs)
diff --git a/tests/providers/apache/kylin/hooks/test_kylin.py b/tests/providers/apache/kylin/hooks/test_kylin.py
index ade147ec3d0f0..c802197d8c357 100644
--- a/tests/providers/apache/kylin/hooks/test_kylin.py
+++ b/tests/providers/apache/kylin/hooks/test_kylin.py
@@ -20,6 +20,7 @@
import unittest
from unittest.mock import MagicMock, patch
+import pytest
from kylinpy.exceptions import KylinCubeError
from airflow.exceptions import AirflowException
@@ -35,7 +36,7 @@ def test_get_job_status(self, mock_job):
job = MagicMock()
job.status = "ERROR"
mock_job.return_value = job
- self.assertEqual(self.hook.get_job_status('123'), "ERROR")
+ assert self.hook.get_job_status('123') == "ERROR"
@patch("kylinpy.Kylin.get_datasource")
def test_cube_run(self, cube_source):
@@ -63,13 +64,12 @@ def invoke_command(self, command, **kwargs):
cube_source.return_value = MockCubeSource()
response_data = {"code": "000", "data": {}}
- self.assertDictEqual(self.hook.cube_run('kylin_sales_cube', 'build'), response_data)
- self.assertDictEqual(self.hook.cube_run('kylin_sales_cube', 'refresh'), response_data)
- self.assertDictEqual(self.hook.cube_run('kylin_sales_cube', 'merge'), response_data)
- self.assertDictEqual(self.hook.cube_run('kylin_sales_cube', 'build_streaming'), response_data)
- self.assertRaises(
- AirflowException,
- self.hook.cube_run,
- 'kylin_sales_cube',
- 'build123',
- )
+ assert self.hook.cube_run('kylin_sales_cube', 'build') == response_data
+ assert self.hook.cube_run('kylin_sales_cube', 'refresh') == response_data
+ assert self.hook.cube_run('kylin_sales_cube', 'merge') == response_data
+ assert self.hook.cube_run('kylin_sales_cube', 'build_streaming') == response_data
+ with pytest.raises(AirflowException):
+ self.hook.cube_run(
+ 'kylin_sales_cube',
+ 'build123',
+ )
diff --git a/tests/providers/apache/kylin/operators/test_kylin_cube.py b/tests/providers/apache/kylin/operators/test_kylin_cube.py
index f7d21a276d3c9..4ad927477378e 100644
--- a/tests/providers/apache/kylin/operators/test_kylin_cube.py
+++ b/tests/providers/apache/kylin/operators/test_kylin_cube.py
@@ -20,6 +20,8 @@
from datetime import datetime
from unittest.mock import MagicMock, patch
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.models import TaskInstance
from airflow.models.dag import DAG
@@ -82,13 +84,13 @@ def test_execute(self, mock_hook):
mock_hook.return_value = hook
mock_hook.cube_run.return_value = {}
- self.assertIsNotNone(operator)
- self.assertEqual(self._config['kylin_conn_id'], operator.kylin_conn_id)
- self.assertEqual(self._config['project'], operator.project)
- self.assertEqual(self._config['cube'], operator.cube)
- self.assertEqual(self._config['command'], operator.command)
- self.assertEqual(self._config['start_time'], operator.start_time)
- self.assertEqual(self._config['end_time'], operator.end_time)
+ assert operator is not None
+ assert self._config['kylin_conn_id'] == operator.kylin_conn_id
+ assert self._config['project'] == operator.project
+ assert self._config['cube'] == operator.cube
+ assert self._config['command'] == operator.command
+ assert self._config['start_time'] == operator.start_time
+ assert self._config['end_time'] == operator.end_time
operator.execute(None)
mock_hook.assert_called_once_with(
kylin_conn_id=self._config['kylin_conn_id'], project=self._config['project'], dsn=None
@@ -115,7 +117,7 @@ def test_execute_build(self, mock_hook):
hook.get_job_status.side_effect = ["RUNNING", "RUNNING", "FINISHED"]
mock_hook.return_value = hook
- self.assertEqual(operator.execute(None)['uuid'], "c143e0e4-ac5f-434d-acf3-46b0d15e3dc6")
+ assert operator.execute(None)['uuid'] == "c143e0e4-ac5f-434d-acf3-46b0d15e3dc6"
@patch('airflow.providers.apache.kylin.operators.kylin_cube.KylinHook')
def test_execute_build_status_error(self, mock_hook):
@@ -128,7 +130,8 @@ def test_execute_build_status_error(self, mock_hook):
hook.get_job_status.return_value = "ERROR"
mock_hook.return_value = hook
- self.assertRaises(AirflowException, operator.execute, None)
+ with pytest.raises(AirflowException):
+ operator.execute(None)
@patch('airflow.providers.apache.kylin.operators.kylin_cube.KylinHook')
def test_execute_build_time_out_error(self, mock_hook):
@@ -141,7 +144,8 @@ def test_execute_build_time_out_error(self, mock_hook):
hook.get_job_status.return_value = "RUNNING"
mock_hook.return_value = hook
- self.assertRaises(AirflowException, operator.execute, None)
+ with pytest.raises(AirflowException):
+ operator.execute(None)
def test_render_template(self):
operator = KylinCubeOperator(
@@ -164,8 +168,8 @@ def test_render_template(self):
)
ti = TaskInstance(operator, DEFAULT_DATE)
ti.render_templates()
- self.assertEqual('learn_kylin', getattr(operator, 'project'))
- self.assertEqual('kylin_sales_cube', getattr(operator, 'cube'))
- self.assertEqual('build', getattr(operator, 'command'))
- self.assertEqual('1483200000000', getattr(operator, 'start_time'))
- self.assertEqual('1483286400000', getattr(operator, 'end_time'))
+ assert 'learn_kylin' == getattr(operator, 'project')
+ assert 'kylin_sales_cube' == getattr(operator, 'cube')
+ assert 'build' == getattr(operator, 'command')
+ assert '1483200000000' == getattr(operator, 'start_time')
+ assert '1483286400000' == getattr(operator, 'end_time')
diff --git a/tests/providers/apache/livy/hooks/test_livy.py b/tests/providers/apache/livy/hooks/test_livy.py
index 316b4706e1b91..86b7acaa6fc39 100644
--- a/tests/providers/apache/livy/hooks/test_livy.py
+++ b/tests/providers/apache/livy/hooks/test_livy.py
@@ -19,6 +19,7 @@
import unittest
from unittest.mock import patch
+import pytest
import requests_mock
from requests.exceptions import RequestException
@@ -63,18 +64,18 @@ def test_build_get_hook(self):
hook = LivyHook(livy_conn_id=conn_id)
hook.get_conn()
- self.assertEqual(hook.base_url, expected)
+ assert hook.base_url == expected
@unittest.skip("inherited HttpHook does not handle missing hostname")
def test_missing_host(self):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
LivyHook(livy_conn_id='missing_host').get_conn()
def test_build_body(self):
with self.subTest('minimal request'):
body = LivyHook.build_post_batch_body(file='appname')
- self.assertEqual(body, {'file': 'appname'})
+ assert body == {'file': 'appname'}
with self.subTest('complex request'):
body = LivyHook.build_post_batch_body(
@@ -96,68 +97,67 @@ def test_build_body(self):
num_executors='10',
)
- self.assertEqual(
- body,
- {
- 'file': 'appname',
- 'className': 'org.example.livy',
- 'proxyUser': 'proxyUser',
- 'args': ['a', '1'],
- 'jars': ['jar1', 'jar2'],
- 'files': ['file1', 'file2'],
- 'pyFiles': ['py1', 'py2'],
- 'archives': ['arch1', 'arch2'],
- 'queue': 'queue',
- 'name': 'name',
- 'conf': {'a': 'b'},
- 'driverCores': 2,
- 'driverMemory': '1M',
- 'executorMemory': '1m',
- 'executorCores': '1',
- 'numExecutors': '10',
- },
- )
+ assert body == {
+ 'file': 'appname',
+ 'className': 'org.example.livy',
+ 'proxyUser': 'proxyUser',
+ 'args': ['a', '1'],
+ 'jars': ['jar1', 'jar2'],
+ 'files': ['file1', 'file2'],
+ 'pyFiles': ['py1', 'py2'],
+ 'archives': ['arch1', 'arch2'],
+ 'queue': 'queue',
+ 'name': 'name',
+ 'conf': {'a': 'b'},
+ 'driverCores': 2,
+ 'driverMemory': '1M',
+ 'executorMemory': '1m',
+ 'executorCores': '1',
+ 'numExecutors': '10',
+ }
def test_parameters_validation(self):
with self.subTest('not a size'):
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
LivyHook.build_post_batch_body(file='appname', executor_memory='xxx')
with self.subTest('list of stringables'):
- self.assertEqual(
- LivyHook.build_post_batch_body(file='appname', args=['a', 1, 0.1])['args'], ['a', '1', '0.1']
- )
+ assert LivyHook.build_post_batch_body(file='appname', args=['a', 1, 0.1])['args'] == [
+ 'a',
+ '1',
+ '0.1',
+ ]
def test_validate_size_format(self):
with self.subTest('lower 1'):
- self.assertTrue(LivyHook._validate_size_format('1m'))
+ assert LivyHook._validate_size_format('1m')
with self.subTest('lower 2'):
- self.assertTrue(LivyHook._validate_size_format('1mb'))
+ assert LivyHook._validate_size_format('1mb')
with self.subTest('upper 1'):
- self.assertTrue(LivyHook._validate_size_format('1G'))
+ assert LivyHook._validate_size_format('1G')
with self.subTest('upper 2'):
- self.assertTrue(LivyHook._validate_size_format('1GB'))
+ assert LivyHook._validate_size_format('1GB')
with self.subTest('snake 1'):
- self.assertTrue(LivyHook._validate_size_format('1Gb'))
+ assert LivyHook._validate_size_format('1Gb')
with self.subTest('fullmatch'):
- with self.assertRaises(ValueError):
- self.assertTrue(LivyHook._validate_size_format('1Gb foo'))
+ with pytest.raises(ValueError):
+ assert LivyHook._validate_size_format('1Gb foo')
with self.subTest('missing size'):
- with self.assertRaises(ValueError):
- self.assertTrue(LivyHook._validate_size_format('10'))
+ with pytest.raises(ValueError):
+ assert LivyHook._validate_size_format('10')
with self.subTest('numeric'):
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
LivyHook._validate_size_format(1) # noqa
with self.subTest('None'):
- self.assertTrue(LivyHook._validate_size_format(None)) # noqa
+ assert LivyHook._validate_size_format(None) # noqa
def test_validate_list_of_stringables(self):
with self.subTest('valid list'):
@@ -179,27 +179,27 @@ def test_validate_list_of_stringables(self):
self.fail("Exception raised")
with self.subTest('dict'):
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
LivyHook._validate_list_of_stringables({'a': 'a'})
with self.subTest('invalid element'):
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
LivyHook._validate_list_of_stringables([1, {}])
with self.subTest('dict'):
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
LivyHook._validate_list_of_stringables([1, None])
with self.subTest('None'):
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
LivyHook._validate_list_of_stringables(None) # noqa
with self.subTest('int'):
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
LivyHook._validate_list_of_stringables(1) # noqa
with self.subTest('string'):
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
LivyHook._validate_list_of_stringables('string')
def test_validate_extra_conf(self):
@@ -222,23 +222,23 @@ def test_validate_extra_conf(self):
self.fail("Exception raised")
with self.subTest('not a dict 1'):
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
LivyHook._validate_extra_conf('k1=v1') # noqa
with self.subTest('not a dict 2'):
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
LivyHook._validate_extra_conf([('k1', 'v1'), ('k2', 0)]) # noqa
with self.subTest('nested dict'):
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
LivyHook._validate_extra_conf({'outer': {'inner': 'val'}})
with self.subTest('empty items'):
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
LivyHook._validate_extra_conf({'has_val': 'val', 'no_val': None})
with self.subTest('empty string'):
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
LivyHook._validate_extra_conf({'has_val': 'val', 'no_val': ''})
@patch('airflow.providers.apache.livy.hooks.livy.LivyHook.run_method')
@@ -259,11 +259,11 @@ def test_post_batch_arguments(self, mock_request):
)
request_args = mock_request.call_args[1]
- self.assertIn('data', request_args)
- self.assertIsInstance(request_args['data'], str)
+ assert 'data' in request_args
+ assert isinstance(request_args['data'], str)
- self.assertIsInstance(resp, int)
- self.assertEqual(resp, BATCH_ID)
+ assert isinstance(resp, int)
+ assert resp == BATCH_ID
@requests_mock.mock()
def test_post_batch_success(self, mock):
@@ -276,15 +276,15 @@ def test_post_batch_success(self, mock):
resp = LivyHook().post_batch(file='sparkapp')
- self.assertIsInstance(resp, int)
- self.assertEqual(resp, BATCH_ID)
+ assert isinstance(resp, int)
+ assert resp == BATCH_ID
@requests_mock.mock()
def test_post_batch_fail(self, mock):
mock.register_uri('POST', '//livy:8998/batches', json={}, status_code=400, reason='ERROR')
hook = LivyHook()
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
hook.post_batch(file='sparkapp')
@requests_mock.mock()
@@ -294,8 +294,8 @@ def test_get_batch_success(self, mock):
hook = LivyHook()
resp = hook.get_batch(BATCH_ID)
- self.assertIsInstance(resp, dict)
- self.assertIn('id', resp)
+ assert isinstance(resp, dict)
+ assert 'id' in resp
@requests_mock.mock()
def test_get_batch_fail(self, mock):
@@ -308,12 +308,12 @@ def test_get_batch_fail(self, mock):
)
hook = LivyHook()
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
hook.get_batch(BATCH_ID)
def test_invalid_uri(self):
hook = LivyHook(livy_conn_id='invalid_uri')
- with self.assertRaises(RequestException):
+ with pytest.raises(RequestException):
hook.post_batch(file='sparkapp')
@requests_mock.mock()
@@ -330,8 +330,8 @@ def test_get_batch_state_success(self, mock):
state = LivyHook().get_batch_state(BATCH_ID)
- self.assertIsInstance(state, BatchState)
- self.assertEqual(state, running)
+ assert isinstance(state, BatchState)
+ assert state == running
@requests_mock.mock()
def test_get_batch_state_fail(self, mock):
@@ -340,7 +340,7 @@ def test_get_batch_state_fail(self, mock):
)
hook = LivyHook()
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
hook.get_batch_state(BATCH_ID)
@requests_mock.mock()
@@ -348,13 +348,13 @@ def test_get_batch_state_missing(self, mock):
mock.register_uri('GET', f'//livy:8998/batches/{BATCH_ID}/state', json={}, status_code=200)
hook = LivyHook()
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
hook.get_batch_state(BATCH_ID)
def test_parse_post_response(self):
res_id = LivyHook._parse_post_response({'id': BATCH_ID, 'log': []})
- self.assertEqual(BATCH_ID, res_id)
+ assert BATCH_ID == res_id
@requests_mock.mock()
def test_delete_batch_success(self, mock):
@@ -364,7 +364,7 @@ def test_delete_batch_success(self, mock):
resp = LivyHook().delete_batch(BATCH_ID)
- self.assertEqual(resp, {'msg': 'deleted'})
+ assert resp == {'msg': 'deleted'}
@requests_mock.mock()
def test_delete_batch_fail(self, mock):
@@ -373,7 +373,7 @@ def test_delete_batch_fail(self, mock):
)
hook = LivyHook()
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
hook.delete_batch(BATCH_ID)
@requests_mock.mock()
@@ -381,7 +381,7 @@ def test_missing_batch_id(self, mock):
mock.register_uri('POST', '//livy:8998/batches', json={}, status_code=201)
hook = LivyHook()
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
hook.post_batch(file='sparkapp')
@requests_mock.mock()
@@ -395,7 +395,7 @@ def test_get_batch_validation(self, mock):
# make sure blocked by validation
for val in [None, 'one', {'a': 'b'}]:
with self.subTest(f'get_batch {val}'):
- with self.assertRaises(TypeError):
+ with pytest.raises(TypeError):
hook.get_batch(val)
@requests_mock.mock()
@@ -410,7 +410,7 @@ def test_get_batch_state_validation(self, mock):
for val in [None, 'one', {'a': 'b'}]:
with self.subTest(f'get_batch {val}'):
- with self.assertRaises(TypeError):
+ with pytest.raises(TypeError):
hook.get_batch_state(val)
@requests_mock.mock()
@@ -423,7 +423,7 @@ def test_delete_batch_validation(self, mock):
for val in [None, 'one', {'a': 'b'}]:
with self.subTest(f'get_batch {val}'):
- with self.assertRaises(TypeError):
+ with pytest.raises(TypeError):
hook.delete_batch(val)
def test_check_session_id(self):
@@ -440,9 +440,9 @@ def test_check_session_id(self):
self.fail("")
with self.subTest('None'):
- with self.assertRaises(TypeError):
+ with pytest.raises(TypeError):
LivyHook._validate_session_id(None) # noqa
with self.subTest('random string'):
- with self.assertRaises(TypeError):
+ with pytest.raises(TypeError):
LivyHook._validate_session_id('asd')
diff --git a/tests/providers/apache/livy/operators/test_livy.py b/tests/providers/apache/livy/operators/test_livy.py
index 86a0aaea706be..dae6004c61838 100644
--- a/tests/providers/apache/livy/operators/test_livy.py
+++ b/tests/providers/apache/livy/operators/test_livy.py
@@ -19,6 +19,8 @@
import unittest
from unittest.mock import MagicMock, patch
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.models.dag import DAG
@@ -60,7 +62,7 @@ def side_effect(_):
task.poll_for_termination(BATCH_ID)
mock_livy.assert_called_with(BATCH_ID)
- self.assertEqual(mock_livy.call_count, 3)
+ assert mock_livy.call_count == 3
@patch('airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state')
def test_poll_for_termination_fail(self, mock_livy):
@@ -78,11 +80,11 @@ def side_effect(_):
task = LivyOperator(file='sparkapp', polling_interval=1, dag=self.dag, task_id='livy_example')
task._livy_hook = task.get_hook()
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
task.poll_for_termination(BATCH_ID)
mock_livy.assert_called_with(BATCH_ID)
- self.assertEqual(mock_livy.call_count, 3)
+ assert mock_livy.call_count == 3
@patch(
'airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state',
@@ -100,7 +102,7 @@ def test_execution(self, mock_post, mock_get):
task.execute(context={})
call_args = {k: v for k, v in mock_post.call_args[1].items() if v}
- self.assertEqual(call_args, {'file': 'sparkapp'})
+ assert call_args == {'file': 'sparkapp'}
mock_get.assert_called_once_with(BATCH_ID)
@patch('airflow.providers.apache.livy.operators.livy.LivyHook.delete_batch')
@@ -120,4 +122,4 @@ def test_injected_hook(self):
task = LivyOperator(file='sparkapp', dag=self.dag, task_id='livy_example')
task._livy_hook = def_hook
- self.assertEqual(task.get_hook(), def_hook)
+ assert task.get_hook() == def_hook
diff --git a/tests/providers/apache/livy/sensors/test_livy.py b/tests/providers/apache/livy/sensors/test_livy.py
index 8440dd9fa69ab..654f120ff8c8b 100644
--- a/tests/providers/apache/livy/sensors/test_livy.py
+++ b/tests/providers/apache/livy/sensors/test_livy.py
@@ -43,4 +43,4 @@ def test_poke(self, mock_state):
for state in BatchState:
with self.subTest(state.value):
mock_state.return_value = state
- self.assertEqual(sensor.poke({}), state in LivyHook.TERMINAL_STATES)
+ assert sensor.poke({}) == (state in LivyHook.TERMINAL_STATES)
diff --git a/tests/providers/apache/pig/hooks/test_pig.py b/tests/providers/apache/pig/hooks/test_pig.py
index a714368b88354..12fbc50885293 100644
--- a/tests/providers/apache/pig/hooks/test_pig.py
+++ b/tests/providers/apache/pig/hooks/test_pig.py
@@ -19,6 +19,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.providers.apache.pig.hooks.pig import PigCliHook
@@ -52,7 +54,7 @@ def test_run_cli_success(self, popen_mock):
hook = self.pig_hook()
stdout = hook.run_cli("")
- self.assertEqual(stdout, "")
+ assert stdout == ""
@mock.patch('subprocess.Popen')
def test_run_cli_fail(self, popen_mock):
@@ -65,7 +67,8 @@ def test_run_cli_fail(self, popen_mock):
from airflow.exceptions import AirflowException
- self.assertRaises(AirflowException, hook.run_cli, "")
+ with pytest.raises(AirflowException):
+ hook.run_cli("")
@mock.patch('subprocess.Popen')
def test_run_cli_with_properties(self, popen_mock):
@@ -80,11 +83,11 @@ def test_run_cli_with_properties(self, popen_mock):
hook.pig_properties = test_properties
stdout = hook.run_cli("")
- self.assertEqual(stdout, "")
+ assert stdout == ""
popen_first_arg = popen_mock.call_args[0][0]
for pig_prop in test_properties.split():
- self.assertIn(pig_prop, popen_first_arg)
+ assert pig_prop in popen_first_arg
@mock.patch('subprocess.Popen')
def test_run_cli_verbose(self, popen_mock):
@@ -99,7 +102,7 @@ def test_run_cli_verbose(self, popen_mock):
hook = self.pig_hook()
stdout = hook.run_cli("", verbose=True)
- self.assertEqual(stdout, "".join(test_stdout_strings))
+ assert stdout == "".join(test_stdout_strings)
def test_kill_no_sp(self):
sp_mock = mock.Mock()
@@ -107,7 +110,7 @@ def test_kill_no_sp(self):
hook.sub_process = sp_mock
hook.kill()
- self.assertFalse(sp_mock.kill.called)
+ assert not sp_mock.kill.called
def test_kill_sp_done(self):
sp_mock = mock.Mock()
@@ -117,7 +120,7 @@ def test_kill_sp_done(self):
hook.sub_process = sp_mock
hook.kill()
- self.assertFalse(sp_mock.kill.called)
+ assert not sp_mock.kill.called
def test_kill(self):
sp_mock = mock.Mock()
@@ -127,4 +130,4 @@ def test_kill(self):
hook.sub_process = sp_mock
hook.kill()
- self.assertTrue(sp_mock.kill.called)
+ assert sp_mock.kill.called
diff --git a/tests/providers/apache/pig/operators/test_pig.py b/tests/providers/apache/pig/operators/test_pig.py
index 92546d122e175..e391beb88b28f 100644
--- a/tests/providers/apache/pig/operators/test_pig.py
+++ b/tests/providers/apache/pig/operators/test_pig.py
@@ -33,12 +33,12 @@ def test_prepare_template(self):
operator = PigOperator(pig=pig, task_id=task_id)
operator.prepare_template()
- self.assertEqual(pig, operator.pig)
+ assert pig == operator.pig
# converts when pigparams_jinja_translate = true
operator = PigOperator(pig=pig, task_id=task_id, pigparams_jinja_translate=True)
operator.prepare_template()
- self.assertEqual("sh echo {{ DATE }};", operator.pig)
+ assert "sh echo {{ DATE }};" == operator.pig
@mock.patch.object(PigCliHook, 'run_cli')
def test_execute(self, mock_run_cli):
diff --git a/tests/providers/apache/pinot/hooks/test_pinot.py b/tests/providers/apache/pinot/hooks/test_pinot.py
index 00d8397e26364..e7e6a5139f478 100644
--- a/tests/providers/apache/pinot/hooks/test_pinot.py
+++ b/tests/providers/apache/pinot/hooks/test_pinot.py
@@ -179,7 +179,7 @@ def test_run_cli_failure_error_message(self, mock_popen):
mock_popen.return_value = mock_proc
params = ["foo", "bar", "baz"]
- with self.assertRaises(AirflowException, msg=msg):
+ with pytest.raises(AirflowException):
self.db_hook.run_cli(params)
params.insert(0, self.conn.extra_dejson.get('cmd_path'))
mock_popen.assert_called_once_with(
@@ -195,7 +195,7 @@ def test_run_cli_failure_status_code(self, mock_popen):
self.db_hook.pinot_admin_system_exit = True
params = ["foo", "bar", "baz"]
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.db_hook.run_cli(params)
params.insert(0, self.conn.extra_dejson.get('cmd_path'))
env = os.environ.copy()
@@ -232,29 +232,29 @@ def test_get_uri(self):
Test on getting a pinot connection uri
"""
db_hook = self.db_hook()
- self.assertEqual(db_hook.get_uri(), 'http://host:1000/query/sql')
+ assert db_hook.get_uri() == 'http://host:1000/query/sql'
def test_get_conn(self):
"""
Test on getting a pinot connection
"""
conn = self.db_hook().get_conn()
- self.assertEqual(conn.host, 'host')
- self.assertEqual(conn.port, '1000')
- self.assertEqual(conn.conn_type, 'http')
- self.assertEqual(conn.extra_dejson.get('endpoint'), 'query/sql')
+ assert conn.host == 'host'
+ assert conn.port == '1000'
+ assert conn.conn_type == 'http'
+ assert conn.extra_dejson.get('endpoint') == 'query/sql'
def test_get_records(self):
statement = 'SQL'
result_sets = [('row1',), ('row2',)]
self.cur.fetchall.return_value = result_sets
- self.assertEqual(result_sets, self.db_hook().get_records(statement))
+ assert result_sets == self.db_hook().get_records(statement)
def test_get_first(self):
statement = 'SQL'
result_sets = [('row1',), ('row2',)]
self.cur.fetchone.return_value = result_sets[0]
- self.assertEqual(result_sets[0], self.db_hook().get_first(statement))
+ assert result_sets[0] == self.db_hook().get_first(statement)
def test_get_pandas_df(self):
statement = 'SQL'
@@ -263,9 +263,9 @@ def test_get_pandas_df(self):
self.cur.description = [(column,)]
self.cur.fetchall.return_value = result_sets
df = self.db_hook().get_pandas_df(statement)
- self.assertEqual(column, df.columns[0])
+ assert column == df.columns[0]
for i in range(len(result_sets)): # pylint: disable=consider-using-enumerate
- self.assertEqual(result_sets[i][0], df.values.tolist()[i][0])
+ assert result_sets[i][0] == df.values.tolist()[i][0]
class TestPinotDbApiHookIntegration(unittest.TestCase):
@@ -275,4 +275,4 @@ def test_should_return_records(self):
hook = PinotDbApiHook()
sql = "select playerName from baseballStats ORDER BY playerName limit 5"
records = hook.get_records(sql)
- self.assertEqual([["A. Harry"], ["A. Harry"], ["Aaron"], ["Aaron Albert"], ["Aaron Albert"]], records)
+ assert [["A. Harry"], ["A. Harry"], ["Aaron"], ["Aaron Albert"], ["Aaron Albert"]] == records
diff --git a/tests/providers/apache/spark/hooks/test_spark_jdbc.py b/tests/providers/apache/spark/hooks/test_spark_jdbc.py
index bd80bccdd8072..3b3b898861bd4 100644
--- a/tests/providers/apache/spark/hooks/test_spark_jdbc.py
+++ b/tests/providers/apache/spark/hooks/test_spark_jdbc.py
@@ -99,7 +99,7 @@ def test_resolve_jdbc_connection(self):
connection = hook._resolve_jdbc_connection()
# Then
- self.assertEqual(connection, expected_connection)
+ assert connection == expected_connection
def test_build_jdbc_arguments(self):
# Given
@@ -143,7 +143,7 @@ def test_build_jdbc_arguments(self):
'-createTableColumnTypes',
'columnMcColumnFace INTEGER(100), name CHAR(64),comments VARCHAR(1024)',
]
- self.assertEqual(expected_jdbc_arguments, cmd)
+ assert expected_jdbc_arguments == cmd
def test_build_jdbc_arguments_invalid(self):
# Given
diff --git a/tests/providers/apache/spark/hooks/test_spark_sql.py b/tests/providers/apache/spark/hooks/test_spark_sql.py
index 5cc001868a0af..85a5159e7a737 100644
--- a/tests/providers/apache/spark/hooks/test_spark_sql.py
+++ b/tests/providers/apache/spark/hooks/test_spark_sql.py
@@ -21,6 +21,8 @@
from itertools import dropwhile
from unittest.mock import call, patch
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.apache.spark.hooks.spark_sql import SparkSqlHook
@@ -104,24 +106,21 @@ def test_spark_process_runcmd(self, mock_popen):
mock_info.assert_called_once_with('Spark-sql communicates using stdout')
# Then
- self.assertEqual(
- mock_popen.mock_calls[0],
- call(
- [
- 'spark-sql',
- '-e',
- 'SELECT 1',
- '--master',
- 'yarn',
- '--name',
- 'default-name',
- '--verbose',
- '--queue',
- 'default',
- ],
- stderr=-2,
- stdout=-1,
- ),
+ assert mock_popen.mock_calls[0] == call(
+ [
+ 'spark-sql',
+ '-e',
+ 'SELECT 1',
+ '--master',
+ 'yarn',
+ '--name',
+ 'default-name',
+ '--verbose',
+ '--queue',
+ 'default',
+ ],
+ stderr=-2,
+ stdout=-1,
)
@patch('airflow.providers.apache.spark.hooks.spark_sql.subprocess.Popen')
@@ -134,26 +133,23 @@ def test_spark_process_runcmd_with_str(self, mock_popen):
hook.run_query('--deploy-mode cluster')
# Then
- self.assertEqual(
- mock_popen.mock_calls[0],
- call(
- [
- 'spark-sql',
- '-e',
- 'SELECT 1',
- '--master',
- 'yarn',
- '--name',
- 'default-name',
- '--verbose',
- '--queue',
- 'default',
- '--deploy-mode',
- 'cluster',
- ],
- stderr=-2,
- stdout=-1,
- ),
+ assert mock_popen.mock_calls[0] == call(
+ [
+ 'spark-sql',
+ '-e',
+ 'SELECT 1',
+ '--master',
+ 'yarn',
+ '--name',
+ 'default-name',
+ '--verbose',
+ '--queue',
+ 'default',
+ '--deploy-mode',
+ 'cluster',
+ ],
+ stderr=-2,
+ stdout=-1,
)
@patch('airflow.providers.apache.spark.hooks.spark_sql.subprocess.Popen')
@@ -166,26 +162,23 @@ def test_spark_process_runcmd_with_list(self, mock_popen):
hook.run_query(['--deploy-mode', 'cluster'])
# Then
- self.assertEqual(
- mock_popen.mock_calls[0],
- call(
- [
- 'spark-sql',
- '-e',
- 'SELECT 1',
- '--master',
- 'yarn',
- '--name',
- 'default-name',
- '--verbose',
- '--queue',
- 'default',
- '--deploy-mode',
- 'cluster',
- ],
- stderr=-2,
- stdout=-1,
- ),
+ assert mock_popen.mock_calls[0] == call(
+ [
+ 'spark-sql',
+ '-e',
+ 'SELECT 1',
+ '--master',
+ 'yarn',
+ '--name',
+ 'default-name',
+ '--verbose',
+ '--queue',
+ 'default',
+ '--deploy-mode',
+ 'cluster',
+ ],
+ stderr=-2,
+ stdout=-1,
)
@patch('airflow.providers.apache.spark.hooks.spark_sql.subprocess.Popen')
@@ -198,7 +191,7 @@ def test_spark_process_runcmd_and_fail(self, mock_popen):
mock_popen.return_value.wait.return_value = status
# When
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
hook = SparkSqlHook(
conn_id='spark_default',
sql=sql,
@@ -207,9 +200,8 @@ def test_spark_process_runcmd_and_fail(self, mock_popen):
hook.run_query(params)
# Then
- self.assertEqual(
- str(e.exception),
- "Cannot execute '{}' on {} (additional parameters: '{}'). Process exit code: {}.".format(
- sql, master, params, status
- ),
+ assert str(
+ ctx.value
+ ) == "Cannot execute '{}' on {} (additional parameters: '{}'). Process exit code: {}.".format(
+ sql, master, params, status
)
diff --git a/tests/providers/apache/spark/hooks/test_spark_submit.py b/tests/providers/apache/spark/hooks/test_spark_submit.py
index 0ccbb5a36988e..eb0a3bf9bf562 100644
--- a/tests/providers/apache/spark/hooks/test_spark_submit.py
+++ b/tests/providers/apache/spark/hooks/test_spark_submit.py
@@ -21,6 +21,7 @@
import unittest
from unittest.mock import call, patch
+import pytest
from parameterized import parameterized
from airflow.exceptions import AirflowException
@@ -197,7 +198,7 @@ def test_build_spark_submit_command(self):
'args should keep embdedded spaces',
'baz',
]
- self.assertEqual(expected_build_cmd, cmd)
+ assert expected_build_cmd == cmd
def test_build_track_driver_status_command(self):
# note this function is only relevant for spark setup matching below condition
@@ -247,15 +248,12 @@ def test_spark_process_runcmd(self, mock_popen):
hook.submit()
# Then
- self.assertEqual(
- mock_popen.mock_calls[0],
- call(
- ['spark-submit', '--master', 'yarn', '--name', 'default-name', ''],
- stderr=-2,
- stdout=-1,
- universal_newlines=True,
- bufsize=-1,
- ),
+ assert mock_popen.mock_calls[0] == call(
+ ['spark-submit', '--master', 'yarn', '--name', 'default-name', ''],
+ stderr=-2,
+ stdout=-1,
+ universal_newlines=True,
+ bufsize=-1,
)
def test_resolve_should_track_driver_status(self):
@@ -296,15 +294,15 @@ def test_resolve_should_track_driver_status(self):
)
# Then
- self.assertEqual(should_track_driver_status_default, False)
- self.assertEqual(should_track_driver_status_spark_yarn_cluster, False)
- self.assertEqual(should_track_driver_status_spark_k8s_cluster, False)
- self.assertEqual(should_track_driver_status_spark_default_mesos, False)
- self.assertEqual(should_track_driver_status_spark_home_set, False)
- self.assertEqual(should_track_driver_status_spark_home_not_set, False)
- self.assertEqual(should_track_driver_status_spark_binary_set, False)
- self.assertEqual(should_track_driver_status_spark_binary_and_home_set, False)
- self.assertEqual(should_track_driver_status_spark_standalone_cluster, True)
+ assert should_track_driver_status_default is False
+ assert should_track_driver_status_spark_yarn_cluster is False
+ assert should_track_driver_status_spark_k8s_cluster is False
+ assert should_track_driver_status_spark_default_mesos is False
+ assert should_track_driver_status_spark_home_set is False
+ assert should_track_driver_status_spark_home_not_set is False
+ assert should_track_driver_status_spark_binary_set is False
+ assert should_track_driver_status_spark_binary_and_home_set is False
+ assert should_track_driver_status_spark_standalone_cluster is True
def test_resolve_connection_yarn_default(self):
# Given
@@ -324,8 +322,8 @@ def test_resolve_connection_yarn_default(self):
"spark_home": None,
"namespace": None,
}
- self.assertEqual(connection, expected_spark_connection)
- self.assertEqual(dict_cmd["--master"], "yarn")
+ assert connection == expected_spark_connection
+ assert dict_cmd["--master"] == "yarn"
def test_resolve_connection_yarn_default_connection(self):
# Given
@@ -345,9 +343,9 @@ def test_resolve_connection_yarn_default_connection(self):
"spark_home": None,
"namespace": None,
}
- self.assertEqual(connection, expected_spark_connection)
- self.assertEqual(dict_cmd["--master"], "yarn")
- self.assertEqual(dict_cmd["--queue"], "root.default")
+ assert connection == expected_spark_connection
+ assert dict_cmd["--master"] == "yarn"
+ assert dict_cmd["--queue"] == "root.default"
def test_resolve_connection_mesos_default_connection(self):
# Given
@@ -367,8 +365,8 @@ def test_resolve_connection_mesos_default_connection(self):
"spark_home": None,
"namespace": None,
}
- self.assertEqual(connection, expected_spark_connection)
- self.assertEqual(dict_cmd["--master"], "mesos://host:5050")
+ assert connection == expected_spark_connection
+ assert dict_cmd["--master"] == "mesos://host:5050"
def test_resolve_connection_spark_yarn_cluster_connection(self):
# Given
@@ -388,10 +386,10 @@ def test_resolve_connection_spark_yarn_cluster_connection(self):
"spark_home": None,
"namespace": None,
}
- self.assertEqual(connection, expected_spark_connection)
- self.assertEqual(dict_cmd["--master"], "yarn://yarn-master")
- self.assertEqual(dict_cmd["--queue"], "root.etl")
- self.assertEqual(dict_cmd["--deploy-mode"], "cluster")
+ assert connection == expected_spark_connection
+ assert dict_cmd["--master"] == "yarn://yarn-master"
+ assert dict_cmd["--queue"] == "root.etl"
+ assert dict_cmd["--deploy-mode"] == "cluster"
def test_resolve_connection_spark_k8s_cluster_connection(self):
# Given
@@ -411,9 +409,9 @@ def test_resolve_connection_spark_k8s_cluster_connection(self):
"deploy_mode": "cluster",
"namespace": "mynamespace",
}
- self.assertEqual(connection, expected_spark_connection)
- self.assertEqual(dict_cmd["--master"], "k8s://https://k8s-master")
- self.assertEqual(dict_cmd["--deploy-mode"], "cluster")
+ assert connection == expected_spark_connection
+ assert dict_cmd["--master"] == "k8s://https://k8s-master"
+ assert dict_cmd["--deploy-mode"] == "cluster"
def test_resolve_connection_spark_k8s_cluster_ns_conf(self):
# Given we specify the config option directly
@@ -436,10 +434,10 @@ def test_resolve_connection_spark_k8s_cluster_ns_conf(self):
"deploy_mode": "cluster",
"namespace": "airflow",
}
- self.assertEqual(connection, expected_spark_connection)
- self.assertEqual(dict_cmd["--master"], "k8s://https://k8s-master")
- self.assertEqual(dict_cmd["--deploy-mode"], "cluster")
- self.assertEqual(dict_cmd["--conf"], "spark.kubernetes.namespace=airflow")
+ assert connection == expected_spark_connection
+ assert dict_cmd["--master"] == "k8s://https://k8s-master"
+ assert dict_cmd["--deploy-mode"] == "cluster"
+ assert dict_cmd["--conf"] == "spark.kubernetes.namespace=airflow"
def test_resolve_connection_spark_home_set_connection(self):
# Given
@@ -458,8 +456,8 @@ def test_resolve_connection_spark_home_set_connection(self):
"spark_home": "/opt/myspark",
"namespace": None,
}
- self.assertEqual(connection, expected_spark_connection)
- self.assertEqual(cmd[0], '/opt/myspark/bin/spark-submit')
+ assert connection == expected_spark_connection
+ assert cmd[0] == '/opt/myspark/bin/spark-submit'
def test_resolve_connection_spark_home_not_set_connection(self):
# Given
@@ -478,8 +476,8 @@ def test_resolve_connection_spark_home_not_set_connection(self):
"spark_home": None,
"namespace": None,
}
- self.assertEqual(connection, expected_spark_connection)
- self.assertEqual(cmd[0], 'spark-submit')
+ assert connection == expected_spark_connection
+ assert cmd[0] == 'spark-submit'
def test_resolve_connection_spark_binary_set_connection(self):
# Given
@@ -498,8 +496,8 @@ def test_resolve_connection_spark_binary_set_connection(self):
"spark_home": None,
"namespace": None,
}
- self.assertEqual(connection, expected_spark_connection)
- self.assertEqual(cmd[0], 'custom-spark-submit')
+ assert connection == expected_spark_connection
+ assert cmd[0] == 'custom-spark-submit'
def test_resolve_connection_spark_binary_default_value_override(self):
# Given
@@ -518,8 +516,8 @@ def test_resolve_connection_spark_binary_default_value_override(self):
"spark_home": None,
"namespace": None,
}
- self.assertEqual(connection, expected_spark_connection)
- self.assertEqual(cmd[0], 'another-custom-spark-submit')
+ assert connection == expected_spark_connection
+ assert cmd[0] == 'another-custom-spark-submit'
def test_resolve_connection_spark_binary_default_value(self):
# Given
@@ -538,8 +536,8 @@ def test_resolve_connection_spark_binary_default_value(self):
"spark_home": None,
"namespace": None,
}
- self.assertEqual(connection, expected_spark_connection)
- self.assertEqual(cmd[0], 'spark-submit')
+ assert connection == expected_spark_connection
+ assert cmd[0] == 'spark-submit'
def test_resolve_connection_spark_binary_and_home_set_connection(self):
# Given
@@ -558,8 +556,8 @@ def test_resolve_connection_spark_binary_and_home_set_connection(self):
"spark_home": "/path/to/spark_home",
"namespace": None,
}
- self.assertEqual(connection, expected_spark_connection)
- self.assertEqual(cmd[0], '/path/to/spark_home/bin/custom-spark-submit')
+ assert connection == expected_spark_connection
+ assert cmd[0] == '/path/to/spark_home/bin/custom-spark-submit'
def test_resolve_connection_spark_standalone_cluster_connection(self):
# Given
@@ -578,8 +576,8 @@ def test_resolve_connection_spark_standalone_cluster_connection(self):
"spark_home": "/path/to/spark_home",
"namespace": None,
}
- self.assertEqual(connection, expected_spark_connection)
- self.assertEqual(cmd[0], '/path/to/spark_home/bin/spark-submit')
+ assert connection == expected_spark_connection
+ assert cmd[0] == '/path/to/spark_home/bin/spark-submit'
def test_resolve_spark_submit_env_vars_standalone_client_mode(self):
# Given
@@ -589,7 +587,7 @@ def test_resolve_spark_submit_env_vars_standalone_client_mode(self):
hook._build_spark_submit_command(self._spark_job_file)
# Then
- self.assertEqual(hook._env, {"bar": "foo"})
+ assert hook._env == {"bar": "foo"}
def test_resolve_spark_submit_env_vars_standalone_cluster_mode(self):
def env_vars_exception_in_standalone_cluster_mode():
@@ -600,7 +598,8 @@ def env_vars_exception_in_standalone_cluster_mode():
hook._build_spark_submit_command(self._spark_job_file)
# Then
- self.assertRaises(AirflowException, env_vars_exception_in_standalone_cluster_mode)
+ with pytest.raises(AirflowException):
+ env_vars_exception_in_standalone_cluster_mode()
def test_resolve_spark_submit_env_vars_yarn(self):
# Given
@@ -610,8 +609,8 @@ def test_resolve_spark_submit_env_vars_yarn(self):
cmd = hook._build_spark_submit_command(self._spark_job_file)
# Then
- self.assertEqual(cmd[4], "spark.yarn.appMasterEnv.bar=foo")
- self.assertEqual(hook._env, {"bar": "foo"})
+ assert cmd[4] == "spark.yarn.appMasterEnv.bar=foo"
+ assert hook._env == {"bar": "foo"}
def test_resolve_spark_submit_env_vars_k8s(self):
# Given
@@ -621,7 +620,7 @@ def test_resolve_spark_submit_env_vars_k8s(self):
cmd = hook._build_spark_submit_command(self._spark_job_file)
# Then
- self.assertEqual(cmd[4], "spark.kubernetes.driverEnv.bar=foo")
+ assert cmd[4] == "spark.kubernetes.driverEnv.bar=foo"
def test_process_spark_submit_log_yarn(self):
# Given
@@ -640,7 +639,7 @@ def test_process_spark_submit_log_yarn(self):
# Then
- self.assertEqual(hook._yarn_application_id, 'application_1486558679801_1820')
+ assert hook._yarn_application_id == 'application_1486558679801_1820'
def test_process_spark_submit_log_k8s(self):
# Given
@@ -672,8 +671,8 @@ def test_process_spark_submit_log_k8s(self):
hook._process_spark_submit_log(log_lines)
# Then
- self.assertEqual(hook._kubernetes_driver_pod, 'spark-pi-edf2ace37be7353a958b38733a12f8e6-driver')
- self.assertEqual(hook._spark_exit_code, 999)
+ assert hook._kubernetes_driver_pod == 'spark-pi-edf2ace37be7353a958b38733a12f8e6-driver'
+ assert hook._spark_exit_code == 999
def test_process_spark_submit_log_k8s_spark_3(self):
# Given
@@ -684,7 +683,7 @@ def test_process_spark_submit_log_k8s_spark_3(self):
hook._process_spark_submit_log(log_lines)
# Then
- self.assertEqual(hook._spark_exit_code, 999)
+ assert hook._spark_exit_code == 999
def test_process_spark_submit_log_standalone_cluster(self):
# Given
@@ -701,7 +700,7 @@ def test_process_spark_submit_log_standalone_cluster(self):
# Then
- self.assertEqual(hook._driver_id, 'driver-20171128111415-0001')
+ assert hook._driver_id == 'driver-20171128111415-0001'
def test_process_spark_driver_status_log(self):
# Given
@@ -726,7 +725,7 @@ def test_process_spark_driver_status_log(self):
# Then
- self.assertEqual(hook._driver_status, 'RUNNING')
+ assert hook._driver_status == 'RUNNING'
@patch('airflow.providers.apache.spark.hooks.spark_submit.renew_from_kt')
@patch('airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen')
@@ -754,14 +753,14 @@ def test_yarn_process_on_kill(self, mock_popen, mock_renew_from_kt):
hook.on_kill()
# Then
- self.assertIn(
+ assert (
call(
['yarn', 'application', '-kill', 'application_1486558679801_1820'],
env=None,
stderr=-1,
stdout=-1,
- ),
- mock_popen.mock_calls,
+ )
+ in mock_popen.mock_calls
)
# resetting the mock to test kill with keytab & principal
mock_popen.reset_mock()
@@ -777,14 +776,14 @@ def test_yarn_process_on_kill(self, mock_popen, mock_renew_from_kt):
# Then
expected_env = os.environ.copy()
expected_env["KRB5CCNAME"] = '/tmp/airflow_krb5_ccache'
- self.assertIn(
+ assert (
call(
['yarn', 'application', '-kill', 'application_1486558679801_1820'],
env=expected_env,
stderr=-1,
stdout=-1,
- ),
- mock_popen.mock_calls,
+ )
+ in mock_popen.mock_calls
)
def test_standalone_cluster_process_on_kill(self):
@@ -803,11 +802,11 @@ def test_standalone_cluster_process_on_kill(self):
kill_cmd = hook._build_spark_driver_kill_command()
# Then
- self.assertEqual(kill_cmd[0], '/path/to/spark_home/bin/spark-submit')
- self.assertEqual(kill_cmd[1], '--master')
- self.assertEqual(kill_cmd[2], 'spark://spark-standalone-master:6066')
- self.assertEqual(kill_cmd[3], '--kill')
- self.assertEqual(kill_cmd[4], 'driver-20171128111415-0001')
+ assert kill_cmd[0] == '/path/to/spark_home/bin/spark-submit'
+ assert kill_cmd[1] == '--master'
+ assert kill_cmd[2] == 'spark://spark-standalone-master:6066'
+ assert kill_cmd[3] == '--kill'
+ assert kill_cmd[4] == 'driver-20171128111415-0001'
@patch('airflow.kubernetes.kube_client.get_kube_client')
@patch('airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen')
diff --git a/tests/providers/apache/spark/operators/test_spark_jdbc.py b/tests/providers/apache/spark/operators/test_spark_jdbc.py
index 25143adf28499..f8c1a564aca22 100644
--- a/tests/providers/apache/spark/operators/test_spark_jdbc.py
+++ b/tests/providers/apache/spark/operators/test_spark_jdbc.py
@@ -99,31 +99,31 @@ def test_execute(self):
'comments VARCHAR(1024)',
}
- self.assertEqual(spark_conn_id, operator._spark_conn_id)
- self.assertEqual(jdbc_conn_id, operator._jdbc_conn_id)
- self.assertEqual(expected_dict['spark_app_name'], operator._spark_app_name)
- self.assertEqual(expected_dict['spark_conf'], operator._spark_conf)
- self.assertEqual(expected_dict['spark_files'], operator._spark_files)
- self.assertEqual(expected_dict['spark_py_files'], operator._spark_py_files)
- self.assertEqual(expected_dict['spark_jars'], operator._spark_jars)
- self.assertEqual(expected_dict['num_executors'], operator._num_executors)
- self.assertEqual(expected_dict['executor_cores'], operator._executor_cores)
- self.assertEqual(expected_dict['executor_memory'], operator._executor_memory)
- self.assertEqual(expected_dict['driver_memory'], operator._driver_memory)
- self.assertEqual(expected_dict['verbose'], operator._verbose)
- self.assertEqual(expected_dict['keytab'], operator._keytab)
- self.assertEqual(expected_dict['principal'], operator._principal)
- self.assertEqual(expected_dict['cmd_type'], operator._cmd_type)
- self.assertEqual(expected_dict['jdbc_table'], operator._jdbc_table)
- self.assertEqual(expected_dict['jdbc_driver'], operator._jdbc_driver)
- self.assertEqual(expected_dict['metastore_table'], operator._metastore_table)
- self.assertEqual(expected_dict['jdbc_truncate'], operator._jdbc_truncate)
- self.assertEqual(expected_dict['save_mode'], operator._save_mode)
- self.assertEqual(expected_dict['save_format'], operator._save_format)
- self.assertEqual(expected_dict['batch_size'], operator._batch_size)
- self.assertEqual(expected_dict['fetch_size'], operator._fetch_size)
- self.assertEqual(expected_dict['num_partitions'], operator._num_partitions)
- self.assertEqual(expected_dict['partition_column'], operator._partition_column)
- self.assertEqual(expected_dict['lower_bound'], operator._lower_bound)
- self.assertEqual(expected_dict['upper_bound'], operator._upper_bound)
- self.assertEqual(expected_dict['create_table_column_types'], operator._create_table_column_types)
+ assert spark_conn_id == operator._spark_conn_id
+ assert jdbc_conn_id == operator._jdbc_conn_id
+ assert expected_dict['spark_app_name'] == operator._spark_app_name
+ assert expected_dict['spark_conf'] == operator._spark_conf
+ assert expected_dict['spark_files'] == operator._spark_files
+ assert expected_dict['spark_py_files'] == operator._spark_py_files
+ assert expected_dict['spark_jars'] == operator._spark_jars
+ assert expected_dict['num_executors'] == operator._num_executors
+ assert expected_dict['executor_cores'] == operator._executor_cores
+ assert expected_dict['executor_memory'] == operator._executor_memory
+ assert expected_dict['driver_memory'] == operator._driver_memory
+ assert expected_dict['verbose'] == operator._verbose
+ assert expected_dict['keytab'] == operator._keytab
+ assert expected_dict['principal'] == operator._principal
+ assert expected_dict['cmd_type'] == operator._cmd_type
+ assert expected_dict['jdbc_table'] == operator._jdbc_table
+ assert expected_dict['jdbc_driver'] == operator._jdbc_driver
+ assert expected_dict['metastore_table'] == operator._metastore_table
+ assert expected_dict['jdbc_truncate'] == operator._jdbc_truncate
+ assert expected_dict['save_mode'] == operator._save_mode
+ assert expected_dict['save_format'] == operator._save_format
+ assert expected_dict['batch_size'] == operator._batch_size
+ assert expected_dict['fetch_size'] == operator._fetch_size
+ assert expected_dict['num_partitions'] == operator._num_partitions
+ assert expected_dict['partition_column'] == operator._partition_column
+ assert expected_dict['lower_bound'] == operator._lower_bound
+ assert expected_dict['upper_bound'] == operator._upper_bound
+ assert expected_dict['create_table_column_types'] == operator._create_table_column_types
diff --git a/tests/providers/apache/spark/operators/test_spark_sql.py b/tests/providers/apache/spark/operators/test_spark_sql.py
index a282e0170afc3..6d1e664eb1bb2 100644
--- a/tests/providers/apache/spark/operators/test_spark_sql.py
+++ b/tests/providers/apache/spark/operators/test_spark_sql.py
@@ -50,18 +50,18 @@ def test_execute(self):
# Given / When
operator = SparkSqlOperator(task_id='spark_sql_job', dag=self.dag, **self._config)
- self.assertEqual(self._config['sql'], operator._sql)
- self.assertEqual(self._config['conn_id'], operator._conn_id)
- self.assertEqual(self._config['total_executor_cores'], operator._total_executor_cores)
- self.assertEqual(self._config['executor_cores'], operator._executor_cores)
- self.assertEqual(self._config['executor_memory'], operator._executor_memory)
- self.assertEqual(self._config['keytab'], operator._keytab)
- self.assertEqual(self._config['principal'], operator._principal)
- self.assertEqual(self._config['executor_memory'], operator._executor_memory)
- self.assertEqual(self._config['keytab'], operator._keytab)
- self.assertEqual(self._config['principal'], operator._principal)
- self.assertEqual(self._config['master'], operator._master)
- self.assertEqual(self._config['name'], operator._name)
- self.assertEqual(self._config['num_executors'], operator._num_executors)
- self.assertEqual(self._config['verbose'], operator._verbose)
- self.assertEqual(self._config['yarn_queue'], operator._yarn_queue)
+ assert self._config['sql'] == operator._sql
+ assert self._config['conn_id'] == operator._conn_id
+ assert self._config['total_executor_cores'] == operator._total_executor_cores
+ assert self._config['executor_cores'] == operator._executor_cores
+ assert self._config['executor_memory'] == operator._executor_memory
+ assert self._config['keytab'] == operator._keytab
+ assert self._config['principal'] == operator._principal
+ assert self._config['executor_memory'] == operator._executor_memory
+ assert self._config['keytab'] == operator._keytab
+ assert self._config['principal'] == operator._principal
+ assert self._config['master'] == operator._master
+ assert self._config['name'] == operator._name
+ assert self._config['num_executors'] == operator._num_executors
+ assert self._config['verbose'] == operator._verbose
+ assert self._config['yarn_queue'] == operator._yarn_queue
diff --git a/tests/providers/apache/spark/operators/test_spark_submit.py b/tests/providers/apache/spark/operators/test_spark_submit.py
index 4dc9a5fca01c7..fdbf3e4961c79 100644
--- a/tests/providers/apache/spark/operators/test_spark_submit.py
+++ b/tests/providers/apache/spark/operators/test_spark_submit.py
@@ -118,31 +118,31 @@ def test_execute(self):
'spark_binary': 'sparky',
}
- self.assertEqual(conn_id, operator._conn_id)
- self.assertEqual(expected_dict['application'], operator._application)
- self.assertEqual(expected_dict['conf'], operator._conf)
- self.assertEqual(expected_dict['files'], operator._files)
- self.assertEqual(expected_dict['py_files'], operator._py_files)
- self.assertEqual(expected_dict['archives'], operator._archives)
- self.assertEqual(expected_dict['driver_class_path'], operator._driver_class_path)
- self.assertEqual(expected_dict['jars'], operator._jars)
- self.assertEqual(expected_dict['packages'], operator._packages)
- self.assertEqual(expected_dict['exclude_packages'], operator._exclude_packages)
- self.assertEqual(expected_dict['repositories'], operator._repositories)
- self.assertEqual(expected_dict['total_executor_cores'], operator._total_executor_cores)
- self.assertEqual(expected_dict['executor_cores'], operator._executor_cores)
- self.assertEqual(expected_dict['executor_memory'], operator._executor_memory)
- self.assertEqual(expected_dict['keytab'], operator._keytab)
- self.assertEqual(expected_dict['principal'], operator._principal)
- self.assertEqual(expected_dict['proxy_user'], operator._proxy_user)
- self.assertEqual(expected_dict['name'], operator._name)
- self.assertEqual(expected_dict['num_executors'], operator._num_executors)
- self.assertEqual(expected_dict['status_poll_interval'], operator._status_poll_interval)
- self.assertEqual(expected_dict['verbose'], operator._verbose)
- self.assertEqual(expected_dict['java_class'], operator._java_class)
- self.assertEqual(expected_dict['driver_memory'], operator._driver_memory)
- self.assertEqual(expected_dict['application_args'], operator._application_args)
- self.assertEqual(expected_dict['spark_binary'], operator._spark_binary)
+ assert conn_id == operator._conn_id
+ assert expected_dict['application'] == operator._application
+ assert expected_dict['conf'] == operator._conf
+ assert expected_dict['files'] == operator._files
+ assert expected_dict['py_files'] == operator._py_files
+ assert expected_dict['archives'] == operator._archives
+ assert expected_dict['driver_class_path'] == operator._driver_class_path
+ assert expected_dict['jars'] == operator._jars
+ assert expected_dict['packages'] == operator._packages
+ assert expected_dict['exclude_packages'] == operator._exclude_packages
+ assert expected_dict['repositories'] == operator._repositories
+ assert expected_dict['total_executor_cores'] == operator._total_executor_cores
+ assert expected_dict['executor_cores'] == operator._executor_cores
+ assert expected_dict['executor_memory'] == operator._executor_memory
+ assert expected_dict['keytab'] == operator._keytab
+ assert expected_dict['principal'] == operator._principal
+ assert expected_dict['proxy_user'] == operator._proxy_user
+ assert expected_dict['name'] == operator._name
+ assert expected_dict['num_executors'] == operator._num_executors
+ assert expected_dict['status_poll_interval'] == operator._status_poll_interval
+ assert expected_dict['verbose'] == operator._verbose
+ assert expected_dict['java_class'] == operator._java_class
+ assert expected_dict['driver_memory'] == operator._driver_memory
+ assert expected_dict['application_args'] == operator._application_args
+ assert expected_dict['spark_binary'] == operator._spark_binary
def test_render_template(self):
# Given
@@ -166,5 +166,5 @@ def test_render_template(self):
'args should keep embdedded spaces',
]
expected_name = 'spark_submit_job'
- self.assertListEqual(expected_application_args, getattr(operator, '_application_args'))
- self.assertEqual(expected_name, getattr(operator, '_name'))
+ assert expected_application_args == getattr(operator, '_application_args')
+ assert expected_name == getattr(operator, '_name')
diff --git a/tests/providers/apache/sqoop/hooks/test_sqoop.py b/tests/providers/apache/sqoop/hooks/test_sqoop.py
index 5b89eb867a0f6..332021aeb7a12 100644
--- a/tests/providers/apache/sqoop/hooks/test_sqoop.py
+++ b/tests/providers/apache/sqoop/hooks/test_sqoop.py
@@ -23,6 +23,8 @@
from io import StringIO
from unittest.mock import call, patch
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.apache.sqoop.hooks.sqoop import SqoopHook
@@ -106,57 +108,54 @@ def test_popen(self, mock_popen):
hook.export_table(**self._config_export)
# Then
- self.assertEqual(
- mock_popen.mock_calls[0],
- call(
- [
- 'sqoop',
- 'export',
- '-fs',
- self._config_json['namenode'],
- '-jt',
- self._config_json['job_tracker'],
- '-libjars',
- self._config_json['libjars'],
- '-files',
- self._config_json['files'],
- '-archives',
- self._config_json['archives'],
- '--connect',
- 'rmdbs:5050/schema',
- '--input-null-string',
- self._config_export['input_null_string'],
- '--input-null-non-string',
- self._config_export['input_null_non_string'],
- '--staging-table',
- self._config_export['staging_table'],
- '--clear-staging-table',
- '--enclosed-by',
- self._config_export['enclosed_by'],
- '--escaped-by',
- self._config_export['escaped_by'],
- '--input-fields-terminated-by',
- self._config_export['input_fields_terminated_by'],
- '--input-lines-terminated-by',
- self._config_export['input_lines_terminated_by'],
- '--input-optionally-enclosed-by',
- self._config_export['input_optionally_enclosed_by'],
- '--batch',
- '--relaxed-isolation',
- '--export-dir',
- self._config_export['export_dir'],
- '--update-key',
- 'id',
- '--update-mode',
- 'allowinsert',
- '--fetch-size',
- str(self._config_export['extra_export_options'].get('fetch-size')),
- '--table',
- self._config_export['table'],
- ],
- stderr=-2,
- stdout=-1,
- ),
+ assert mock_popen.mock_calls[0] == call(
+ [
+ 'sqoop',
+ 'export',
+ '-fs',
+ self._config_json['namenode'],
+ '-jt',
+ self._config_json['job_tracker'],
+ '-libjars',
+ self._config_json['libjars'],
+ '-files',
+ self._config_json['files'],
+ '-archives',
+ self._config_json['archives'],
+ '--connect',
+ 'rmdbs:5050/schema',
+ '--input-null-string',
+ self._config_export['input_null_string'],
+ '--input-null-non-string',
+ self._config_export['input_null_non_string'],
+ '--staging-table',
+ self._config_export['staging_table'],
+ '--clear-staging-table',
+ '--enclosed-by',
+ self._config_export['enclosed_by'],
+ '--escaped-by',
+ self._config_export['escaped_by'],
+ '--input-fields-terminated-by',
+ self._config_export['input_fields_terminated_by'],
+ '--input-lines-terminated-by',
+ self._config_export['input_lines_terminated_by'],
+ '--input-optionally-enclosed-by',
+ self._config_export['input_optionally_enclosed_by'],
+ '--batch',
+ '--relaxed-isolation',
+ '--export-dir',
+ self._config_export['export_dir'],
+ '--update-key',
+ 'id',
+ '--update-mode',
+ 'allowinsert',
+ '--fetch-size',
+ str(self._config_export['extra_export_options'].get('fetch-size')),
+ '--table',
+ self._config_export['table'],
+ ],
+ stderr=-2,
+ stdout=-1,
)
def test_submit_none_mappers(self):
@@ -168,7 +167,7 @@ def test_submit_none_mappers(self):
hook = SqoopHook(**_config_without_mappers)
cmd = ' '.join(hook._prepare_command())
- self.assertNotIn('--num-mappers', cmd)
+ assert '--num-mappers' not in cmd
def test_submit(self):
"""
@@ -180,42 +179,42 @@ def test_submit(self):
# Check if the config has been extracted from the json
if self._config_json['namenode']:
- self.assertIn("-fs {}".format(self._config_json['namenode']), cmd)
+ assert "-fs {}".format(self._config_json['namenode']) in cmd
if self._config_json['job_tracker']:
- self.assertIn("-jt {}".format(self._config_json['job_tracker']), cmd)
+ assert "-jt {}".format(self._config_json['job_tracker']) in cmd
if self._config_json['libjars']:
- self.assertIn("-libjars {}".format(self._config_json['libjars']), cmd)
+ assert "-libjars {}".format(self._config_json['libjars']) in cmd
if self._config_json['files']:
- self.assertIn("-files {}".format(self._config_json['files']), cmd)
+ assert "-files {}".format(self._config_json['files']) in cmd
if self._config_json['archives']:
- self.assertIn("-archives {}".format(self._config_json['archives']), cmd)
+ assert "-archives {}".format(self._config_json['archives']) in cmd
- self.assertIn("--hcatalog-database {}".format(self._config['hcatalog_database']), cmd)
- self.assertIn("--hcatalog-table {}".format(self._config['hcatalog_table']), cmd)
+ assert "--hcatalog-database {}".format(self._config['hcatalog_database']) in cmd
+ assert "--hcatalog-table {}".format(self._config['hcatalog_table']) in cmd
# Check the regulator stuff passed by the default constructor
if self._config['verbose']:
- self.assertIn("--verbose", cmd)
+ assert "--verbose" in cmd
if self._config['num_mappers']:
- self.assertIn("--num-mappers {}".format(self._config['num_mappers']), cmd)
+ assert "--num-mappers {}".format(self._config['num_mappers']) in cmd
for key, value in self._config['properties'].items():
- self.assertIn(f"-D {key}={value}", cmd)
+ assert f"-D {key}={value}" in cmd
# We don't have the sqoop binary available, and this is hard to mock,
# so just accept an exception for now.
- with self.assertRaises(OSError):
+ with pytest.raises(OSError):
hook.export_table(**self._config_export)
- with self.assertRaises(OSError):
+ with pytest.raises(OSError):
hook.import_table(table='schema.table', target_dir='/sqoop/example/path')
- with self.assertRaises(OSError):
+ with pytest.raises(OSError):
hook.import_query(query='SELECT * FROM sometable', target_dir='/sqoop/example/path')
def test_export_cmd(self):
@@ -244,38 +243,38 @@ def test_export_cmd(self):
)
)
- self.assertIn("--input-null-string {}".format(self._config_export['input_null_string']), cmd)
- self.assertIn("--input-null-non-string {}".format(self._config_export['input_null_non_string']), cmd)
- self.assertIn("--staging-table {}".format(self._config_export['staging_table']), cmd)
- self.assertIn("--enclosed-by {}".format(self._config_export['enclosed_by']), cmd)
- self.assertIn("--escaped-by {}".format(self._config_export['escaped_by']), cmd)
- self.assertIn(
- "--input-fields-terminated-by {}".format(self._config_export['input_fields_terminated_by']), cmd
+ assert "--input-null-string {}".format(self._config_export['input_null_string']) in cmd
+ assert "--input-null-non-string {}".format(self._config_export['input_null_non_string']) in cmd
+ assert "--staging-table {}".format(self._config_export['staging_table']) in cmd
+ assert "--enclosed-by {}".format(self._config_export['enclosed_by']) in cmd
+ assert "--escaped-by {}".format(self._config_export['escaped_by']) in cmd
+ assert (
+ "--input-fields-terminated-by {}".format(self._config_export['input_fields_terminated_by']) in cmd
)
- self.assertIn(
- "--input-lines-terminated-by {}".format(self._config_export['input_lines_terminated_by']), cmd
+ assert (
+ "--input-lines-terminated-by {}".format(self._config_export['input_lines_terminated_by']) in cmd
)
- self.assertIn(
- "--input-optionally-enclosed-by {}".format(self._config_export['input_optionally_enclosed_by']),
- cmd,
+ assert (
+ "--input-optionally-enclosed-by {}".format(self._config_export['input_optionally_enclosed_by'])
+ in cmd
)
# these options are from the extra export options
- self.assertIn("--update-key id", cmd)
- self.assertIn("--update-mode allowinsert", cmd)
+ assert "--update-key id" in cmd
+ assert "--update-mode allowinsert" in cmd
if self._config_export['clear_staging_table']:
- self.assertIn("--clear-staging-table", cmd)
+ assert "--clear-staging-table" in cmd
if self._config_export['batch']:
- self.assertIn("--batch", cmd)
+ assert "--batch" in cmd
if self._config_export['relaxed_isolation']:
- self.assertIn("--relaxed-isolation", cmd)
+ assert "--relaxed-isolation" in cmd
if self._config_export['extra_export_options']:
- self.assertIn("--update-key", cmd)
- self.assertIn("--update-mode", cmd)
- self.assertIn("--fetch-size", cmd)
+ assert "--update-key" in cmd
+ assert "--update-mode" in cmd
+ assert "--fetch-size" in cmd
def test_import_cmd(self):
"""
@@ -297,18 +296,18 @@ def test_import_cmd(self):
)
if self._config_import['append']:
- self.assertIn('--append', cmd)
+ assert '--append' in cmd
if self._config_import['direct']:
- self.assertIn('--direct', cmd)
+ assert '--direct' in cmd
- self.assertIn('--target-dir {}'.format(self._config_import['target_dir']), cmd)
+ assert '--target-dir {}'.format(self._config_import['target_dir']) in cmd
- self.assertIn('--driver {}'.format(self._config_import['driver']), cmd)
- self.assertIn('--split-by {}'.format(self._config_import['split_by']), cmd)
+ assert '--driver {}'.format(self._config_import['driver']) in cmd
+ assert '--split-by {}'.format(self._config_import['split_by']) in cmd
# these are from extra options, but not passed to this cmd import command
- self.assertNotIn('--show', cmd)
- self.assertNotIn('hcatalog-storage-stanza \"stored as orcfile\"', cmd)
+ assert '--show' not in cmd
+ assert 'hcatalog-storage-stanza \"stored as orcfile\"' not in cmd
cmd = ' '.join(
hook._import_cmd(
@@ -322,11 +321,11 @@ def test_import_cmd(self):
)
)
- self.assertNotIn('--target-dir', cmd)
+ assert '--target-dir' not in cmd
# these checks are from the extra import options
- self.assertIn('--show', cmd)
- self.assertIn('hcatalog-storage-stanza \"stored as orcfile\"', cmd)
- self.assertIn('--fetch-size', cmd)
+ assert '--show' in cmd
+ assert 'hcatalog-storage-stanza \"stored as orcfile\"' in cmd
+ assert '--fetch-size' in cmd
def test_get_export_format_argument(self):
"""
@@ -334,11 +333,11 @@ def test_get_export_format_argument(self):
correct Sqoop command with correct format type.
"""
hook = SqoopHook()
- self.assertIn("--as-avrodatafile", hook._get_export_format_argument('avro'))
- self.assertIn("--as-parquetfile", hook._get_export_format_argument('parquet'))
- self.assertIn("--as-sequencefile", hook._get_export_format_argument('sequence'))
- self.assertIn("--as-textfile", hook._get_export_format_argument('text'))
- with self.assertRaises(AirflowException):
+ assert "--as-avrodatafile" in hook._get_export_format_argument('avro')
+ assert "--as-parquetfile" in hook._get_export_format_argument('parquet')
+ assert "--as-sequencefile" in hook._get_export_format_argument('sequence')
+ assert "--as-textfile" in hook._get_export_format_argument('text')
+ with pytest.raises(AirflowException):
hook._get_export_format_argument('unknown')
def test_cmd_mask_password(self):
@@ -346,7 +345,7 @@ def test_cmd_mask_password(self):
Tests to verify the hook masking function will correctly mask a user password in Sqoop command.
"""
hook = SqoopHook()
- self.assertEqual(hook.cmd_mask_password(['--password', 'supersecret']), ['--password', 'MASKED'])
+ assert hook.cmd_mask_password(['--password', 'supersecret']) == ['--password', 'MASKED']
cmd = ['--target', 'targettable']
- self.assertEqual(hook.cmd_mask_password(cmd), cmd)
+ assert hook.cmd_mask_password(cmd) == cmd
diff --git a/tests/providers/apache/sqoop/operators/test_sqoop.py b/tests/providers/apache/sqoop/operators/test_sqoop.py
index 295ca841e61fd..882d13a5ec351 100644
--- a/tests/providers/apache/sqoop/operators/test_sqoop.py
+++ b/tests/providers/apache/sqoop/operators/test_sqoop.py
@@ -20,6 +20,8 @@
import datetime
import unittest
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.models.dag import DAG
from airflow.providers.apache.sqoop.operators.sqoop import SqoopOperator
@@ -69,29 +71,29 @@ def test_execute(self):
"""
operator = SqoopOperator(task_id='sqoop_job', dag=self.dag, **self._config)
- self.assertEqual(self._config['conn_id'], operator.conn_id)
- self.assertEqual(self._config['query'], operator.query)
- self.assertEqual(self._config['cmd_type'], operator.cmd_type)
- self.assertEqual(self._config['table'], operator.table)
- self.assertEqual(self._config['target_dir'], operator.target_dir)
- self.assertEqual(self._config['append'], operator.append)
- self.assertEqual(self._config['file_type'], operator.file_type)
- self.assertEqual(self._config['num_mappers'], operator.num_mappers)
- self.assertEqual(self._config['split_by'], operator.split_by)
- self.assertEqual(self._config['input_null_string'], operator.input_null_string)
- self.assertEqual(self._config['input_null_non_string'], operator.input_null_non_string)
- self.assertEqual(self._config['staging_table'], operator.staging_table)
- self.assertEqual(self._config['clear_staging_table'], operator.clear_staging_table)
- self.assertEqual(self._config['batch'], operator.batch)
- self.assertEqual(self._config['relaxed_isolation'], operator.relaxed_isolation)
- self.assertEqual(self._config['direct'], operator.direct)
- self.assertEqual(self._config['driver'], operator.driver)
- self.assertEqual(self._config['properties'], operator.properties)
- self.assertEqual(self._config['hcatalog_database'], operator.hcatalog_database)
- self.assertEqual(self._config['hcatalog_table'], operator.hcatalog_table)
- self.assertEqual(self._config['create_hcatalog_table'], operator.create_hcatalog_table)
- self.assertEqual(self._config['extra_import_options'], operator.extra_import_options)
- self.assertEqual(self._config['extra_export_options'], operator.extra_export_options)
+ assert self._config['conn_id'] == operator.conn_id
+ assert self._config['query'] == operator.query
+ assert self._config['cmd_type'] == operator.cmd_type
+ assert self._config['table'] == operator.table
+ assert self._config['target_dir'] == operator.target_dir
+ assert self._config['append'] == operator.append
+ assert self._config['file_type'] == operator.file_type
+ assert self._config['num_mappers'] == operator.num_mappers
+ assert self._config['split_by'] == operator.split_by
+ assert self._config['input_null_string'] == operator.input_null_string
+ assert self._config['input_null_non_string'] == operator.input_null_non_string
+ assert self._config['staging_table'] == operator.staging_table
+ assert self._config['clear_staging_table'] == operator.clear_staging_table
+ assert self._config['batch'] == operator.batch
+ assert self._config['relaxed_isolation'] == operator.relaxed_isolation
+ assert self._config['direct'] == operator.direct
+ assert self._config['driver'] == operator.driver
+ assert self._config['properties'] == operator.properties
+ assert self._config['hcatalog_database'] == operator.hcatalog_database
+ assert self._config['hcatalog_table'] == operator.hcatalog_table
+ assert self._config['create_hcatalog_table'] == operator.create_hcatalog_table
+ assert self._config['extra_import_options'] == operator.extra_import_options
+ assert self._config['extra_export_options'] == operator.extra_export_options
# the following are meant to be more of examples
SqoopOperator(
@@ -174,7 +176,7 @@ def test_invalid_cmd_type(self):
Tests to verify if the cmd_type is not import or export, an exception is raised.
"""
operator = SqoopOperator(task_id='sqoop_job', dag=self.dag, cmd_type='invalid')
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
operator.execute({})
def test_invalid_import_options(self):
@@ -184,5 +186,5 @@ def test_invalid_import_options(self):
import_query_and_table_configs = self._config.copy()
import_query_and_table_configs['cmd_type'] = 'import'
operator = SqoopOperator(task_id='sqoop_job', dag=self.dag, **import_query_and_table_configs)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
operator.execute({})
diff --git a/tests/providers/celery/sensors/test_celery_queue.py b/tests/providers/celery/sensors/test_celery_queue.py
index ba2b15903076c..4b3820802fa18 100644
--- a/tests/providers/celery/sensors/test_celery_queue.py
+++ b/tests/providers/celery/sensors/test_celery_queue.py
@@ -40,7 +40,7 @@ def test_poke_success(self, mock_inspect):
mock_inspect_result.active.return_value = {'test_queue': []}
test_sensor = self.sensor(celery_queue='test_queue', task_id='test-task')
- self.assertTrue(test_sensor.poke(None))
+ assert test_sensor.poke(None)
@patch('celery.app.control.Inspect')
def test_poke_fail(self, mock_inspect):
@@ -52,11 +52,11 @@ def test_poke_fail(self, mock_inspect):
mock_inspect_result.active.return_value = {'test_queue': ['task']}
test_sensor = self.sensor(celery_queue='test_queue', task_id='test-task')
- self.assertFalse(test_sensor.poke(None))
+ assert not test_sensor.poke(None)
@patch('celery.app.control.Inspect')
def test_poke_success_with_taskid(self, mock_inspect):
test_sensor = self.sensor(
celery_queue='test_queue', task_id='test-task', target_task_id='target-task'
)
- self.assertTrue(test_sensor.poke(None))
+ assert test_sensor.poke(None)
diff --git a/tests/providers/cloudant/hooks/test_cloudant.py b/tests/providers/cloudant/hooks/test_cloudant.py
index d0c78dbebb97c..427a92259c9c8 100644
--- a/tests/providers/cloudant/hooks/test_cloudant.py
+++ b/tests/providers/cloudant/hooks/test_cloudant.py
@@ -18,6 +18,8 @@
import unittest
from unittest.mock import patch
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.cloudant.hooks.cloudant import CloudantHook
@@ -37,12 +39,12 @@ def test_get_conn(self, mock_cloudant, mock_get_connection):
conn = mock_get_connection.return_value
mock_cloudant.assert_called_once_with(user=conn.login, passwd=conn.password, account=conn.host)
- self.assertEqual(cloudant_session, mock_cloudant.return_value)
+ assert cloudant_session == mock_cloudant.return_value
@patch(
'airflow.providers.cloudant.hooks.cloudant.CloudantHook.get_connection',
return_value=Connection(login='user'),
)
def test_get_conn_invalid_connection(self, mock_get_connection):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.cloudant_hook.get_conn()
diff --git a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
index 9a01d8fa5007b..c8c2de498209b 100644
--- a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
+++ b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
@@ -24,6 +24,7 @@
from unittest.mock import patch
import kubernetes
+import pytest
from parameterized import parameterized
from airflow import AirflowException
@@ -77,7 +78,7 @@ def test_in_cluster_connection(self, mock_kube_config_loader):
kubernetes_hook = KubernetesHook(conn_id='kubernetes_in_cluster')
api_conn = kubernetes_hook.get_conn()
mock_kube_config_loader.assert_called_once()
- self.assertIsInstance(api_conn, kubernetes.client.api_client.ApiClient)
+ assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)
@patch("kubernetes.config.kube_config.KubeConfigLoader")
@patch("kubernetes.config.kube_config.KubeConfigMerger")
@@ -86,7 +87,7 @@ def test_kube_config_path(self, mock_kube_config_loader, mock_kube_config_merger
api_conn = kubernetes_hook.get_conn()
mock_kube_config_loader.assert_called_once_with("path/to/file")
mock_kube_config_merger.assert_called_once()
- self.assertIsInstance(api_conn, kubernetes.client.api_client.ApiClient)
+ assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)
@patch("kubernetes.config.kube_config.KubeConfigLoader")
@patch("kubernetes.config.kube_config.KubeConfigMerger")
@@ -97,7 +98,7 @@ def test_kube_config_connection(self, mock_kube_config_loader, mock_kube_config_
mock_tempfile.is_called_once()
mock_kube_config_loader.assert_called_once()
mock_kube_config_merger.assert_called_once()
- self.assertIsInstance(api_conn, kubernetes.client.api_client.ApiClient)
+ assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)
@patch("kubernetes.config.kube_config.KubeConfigLoader")
@patch("kubernetes.config.kube_config.KubeConfigMerger")
@@ -111,13 +112,13 @@ def test_default_kube_config_connection(
api_conn = kubernetes_hook.get_conn()
mock_kube_config_loader.assert_called_once_with("/mock/config")
mock_kube_config_merger.assert_called_once()
- self.assertIsInstance(api_conn, kubernetes.client.api_client.ApiClient)
+ assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)
def test_get_namespace(self):
kubernetes_hook_with_namespace = KubernetesHook(conn_id='kubernetes_with_namespace')
kubernetes_hook_without_namespace = KubernetesHook(conn_id='kubernetes_default_kube_config')
- self.assertEqual(kubernetes_hook_with_namespace.get_namespace(), 'mock_namespace')
- self.assertEqual(kubernetes_hook_without_namespace.get_namespace(), 'default')
+ assert kubernetes_hook_with_namespace.get_namespace() == 'mock_namespace'
+ assert kubernetes_hook_without_namespace.get_namespace() == 'default'
class TestKubernetesHookIncorrectConfiguration(unittest.TestCase):
@@ -129,8 +130,8 @@ class TestKubernetesHookIncorrectConfiguration(unittest.TestCase):
)
)
def test_should_raise_exception_on_invalid_configuration(self, conn_uri):
- with mock.patch.dict("os.environ", AIRFLOW_CONN_KUBERNETES_DEFAULT=conn_uri), self.assertRaisesRegex(
- AirflowException, "Invalid connection configuration"
+ with mock.patch.dict("os.environ", AIRFLOW_CONN_KUBERNETES_DEFAULT=conn_uri), pytest.raises(
+ AirflowException, match="Invalid connection configuration"
):
kubernetes_hook = KubernetesHook()
kubernetes_hook.get_conn()
diff --git a/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py b/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py
index 016922b7186f7..b92a0c9cd99dc 100644
--- a/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py
+++ b/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py
@@ -19,6 +19,7 @@
from unittest import mock
import pendulum
+import pytest
from kubernetes.client import ApiClient, models as k8s
from airflow.exceptions import AirflowException
@@ -94,10 +95,9 @@ def test_image_pull_secrets_correctly_set(self, mock_client, monitor_mock, start
monitor_mock.return_value = (State.SUCCESS, None)
context = self.create_context(k)
k.execute(context=context)
- self.assertEqual(
- start_mock.call_args[0][0].spec.image_pull_secrets,
- [k8s.V1LocalObjectReference(name=fake_pull_secrets)],
- )
+ assert start_mock.call_args[0][0].spec.image_pull_secrets == [
+ k8s.V1LocalObjectReference(name=fake_pull_secrets)
+ ]
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
@@ -120,10 +120,7 @@ def test_image_pull_policy_not_set(self, mock_client, monitor_mock, start_mock):
monitor_mock.return_value = (State.SUCCESS, None)
context = self.create_context(k)
k.execute(context=context)
- self.assertEqual(
- start_mock.call_args[0][0].spec.containers[0].image_pull_policy,
- 'IfNotPresent',
- )
+ assert start_mock.call_args[0][0].spec.containers[0].image_pull_policy == 'IfNotPresent'
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
@@ -147,10 +144,7 @@ def test_image_pull_policy_correctly_set(self, mock_client, monitor_mock, start_
monitor_mock.return_value = (State.SUCCESS, None)
context = self.create_context(k)
k.execute(context=context)
- self.assertEqual(
- start_mock.call_args[0][0].spec.containers[0].image_pull_policy,
- 'Always',
- )
+ assert start_mock.call_args[0][0].spec.containers[0].image_pull_policy == 'Always'
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
@@ -173,7 +167,7 @@ def test_pod_delete_even_on_launcher_error(
is_delete_operator_pod=True,
)
monitor_pod_mock.side_effect = AirflowException('fake failure')
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
context = self.create_context(k)
k.execute(context=context)
assert delete_pod_mock.called
@@ -187,9 +181,9 @@ def test_jinja_templated_fields(self):
task_id="task",
)
- self.assertEqual(task.image, "{{ image_jinja }}:16.04")
+ assert task.image == "{{ image_jinja }}:16.04"
task.render_template_fields(context={"image_jinja": "ubuntu"})
- self.assertEqual(task.image, "ubuntu:16.04")
+ assert task.image == "ubuntu:16.04"
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
@@ -243,16 +237,16 @@ def test_describes_pod_on_failure(self, mock_client, monitor_mock, start_mock):
read_namespaced_pod_mock = mock_client.return_value.read_namespaced_pod
read_namespaced_pod_mock.return_value = failed_pod_status
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
context = self.create_context(k)
k.execute(context=context)
- self.assertEqual(
- str(cm.exception),
- f"Pod Launching failed: Pod {k.pod.metadata.name} returned a failure: {failed_pod_status}",
+ assert (
+ str(ctx.value)
+ == f"Pod Launching failed: Pod {k.pod.metadata.name} returned a failure: {failed_pod_status}"
)
assert mock_client.return_value.read_namespaced_pod.called
- self.assertEqual(read_namespaced_pod_mock.call_args[0][0], k.pod.metadata.name)
+ assert read_namespaced_pod_mock.call_args[0][0] == k.pod.metadata.name
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
@@ -313,8 +307,8 @@ def test_create_with_affinity(self):
result = k.create_pod_request_obj()
client = ApiClient()
- self.assertEqual(type(result.spec.affinity), k8s.V1Affinity)
- self.assertEqual(client.sanitize_for_serialization(result)['spec']['affinity'], affinity)
+ assert isinstance(result.spec.affinity, k8s.V1Affinity)
+ assert client.sanitize_for_serialization(result)['spec']['affinity'] == affinity
k8s_api_affinity = k8s.V1Affinity(
node_affinity=k8s.V1NodeAffinity(
@@ -346,8 +340,8 @@ def test_create_with_affinity(self):
)
result = k.create_pod_request_obj()
- self.assertEqual(type(result.spec.affinity), k8s.V1Affinity)
- self.assertEqual(client.sanitize_for_serialization(result)['spec']['affinity'], affinity)
+ assert isinstance(result.spec.affinity, k8s.V1Affinity)
+ assert client.sanitize_for_serialization(result)['spec']['affinity'] == affinity
def test_tolerations(self):
k8s_api_tolerations = [k8s.V1Toleration(key="key", operator="Equal", value="value")]
@@ -370,8 +364,8 @@ def test_tolerations(self):
result = k.create_pod_request_obj()
client = ApiClient()
- self.assertEqual(type(result.spec.tolerations[0]), k8s.V1Toleration)
- self.assertEqual(client.sanitize_for_serialization(result)['spec']['tolerations'], tolerations)
+ assert isinstance(result.spec.tolerations[0], k8s.V1Toleration)
+ assert client.sanitize_for_serialization(result)['spec']['tolerations'] == tolerations
k = KubernetesPodOperator(
namespace='default',
@@ -388,8 +382,8 @@ def test_tolerations(self):
)
result = k.create_pod_request_obj()
- self.assertEqual(type(result.spec.tolerations[0]), k8s.V1Toleration)
- self.assertEqual(client.sanitize_for_serialization(result)['spec']['tolerations'], tolerations)
+ assert isinstance(result.spec.tolerations[0], k8s.V1Toleration)
+ assert client.sanitize_for_serialization(result)['spec']['tolerations'] == tolerations
def test_node_selector(self):
node_selector = {'beta.kubernetes.io/os': 'linux'}
@@ -410,8 +404,8 @@ def test_node_selector(self):
result = k.create_pod_request_obj()
client = ApiClient()
- self.assertEqual(type(result.spec.node_selector), dict)
- self.assertEqual(client.sanitize_for_serialization(result)['spec']['nodeSelector'], node_selector)
+ assert isinstance(result.spec.node_selector, dict)
+ assert client.sanitize_for_serialization(result)['spec']['nodeSelector'] == node_selector
# repeat tests using deprecated parameter
k = KubernetesPodOperator(
@@ -430,5 +424,5 @@ def test_node_selector(self):
result = k.create_pod_request_obj()
client = ApiClient()
- self.assertEqual(type(result.spec.node_selector), dict)
- self.assertEqual(client.sanitize_for_serialization(result)['spec']['nodeSelector'], node_selector)
+ assert isinstance(result.spec.node_selector, dict)
+ assert client.sanitize_for_serialization(result)['spec']['nodeSelector'] == node_selector
diff --git a/tests/providers/cncf/kubernetes/sensors/test_spark_kubernetes.py b/tests/providers/cncf/kubernetes/sensors/test_spark_kubernetes.py
index 437c48b146344..2bfb50e5dfd54 100644
--- a/tests/providers/cncf/kubernetes/sensors/test_spark_kubernetes.py
+++ b/tests/providers/cncf/kubernetes/sensors/test_spark_kubernetes.py
@@ -21,6 +21,7 @@
import unittest
from unittest.mock import patch
+import pytest
from kubernetes.client.rest import ApiException
from airflow import DAG
@@ -513,7 +514,7 @@ def setUp(self):
)
def test_completed_application(self, mock_get_namespaced_crd, mock_kubernetes_hook):
sensor = SparkKubernetesSensor(application_name="spark_pi", dag=self.dag, task_id="test_task_id")
- self.assertTrue(sensor.poke(None))
+ assert sensor.poke(None)
mock_kubernetes_hook.assert_called_once_with()
mock_get_namespaced_crd.assert_called_once_with(
group="sparkoperator.k8s.io",
@@ -529,7 +530,8 @@ def test_completed_application(self, mock_get_namespaced_crd, mock_kubernetes_ho
)
def test_failed_application(self, mock_get_namespaced_crd, mock_kubernetes_hook):
sensor = SparkKubernetesSensor(application_name="spark_pi", dag=self.dag, task_id="test_task_id")
- self.assertRaises(AirflowException, sensor.poke, None)
+ with pytest.raises(AirflowException):
+ sensor.poke(None)
mock_kubernetes_hook.assert_called_once_with()
mock_get_namespaced_crd.assert_called_once_with(
group="sparkoperator.k8s.io",
@@ -545,7 +547,7 @@ def test_failed_application(self, mock_get_namespaced_crd, mock_kubernetes_hook)
)
def test_not_processed_application(self, mock_get_namespaced_crd, mock_kubernetes_hook):
sensor = SparkKubernetesSensor(application_name="spark_pi", dag=self.dag, task_id="test_task_id")
- self.assertFalse(sensor.poke(None))
+ assert not sensor.poke(None)
mock_kubernetes_hook.assert_called_once_with()
mock_get_namespaced_crd.assert_called_once_with(
group="sparkoperator.k8s.io",
@@ -561,7 +563,7 @@ def test_not_processed_application(self, mock_get_namespaced_crd, mock_kubernete
)
def test_new_application(self, mock_get_namespaced_crd, mock_kubernetes_hook):
sensor = SparkKubernetesSensor(application_name="spark_pi", dag=self.dag, task_id="test_task_id")
- self.assertFalse(sensor.poke(None))
+ assert not sensor.poke(None)
mock_kubernetes_hook.assert_called_once_with()
mock_get_namespaced_crd.assert_called_once_with(
group="sparkoperator.k8s.io",
@@ -577,7 +579,7 @@ def test_new_application(self, mock_get_namespaced_crd, mock_kubernetes_hook):
)
def test_running_application(self, mock_get_namespaced_crd, mock_kubernetes_hook):
sensor = SparkKubernetesSensor(application_name="spark_pi", dag=self.dag, task_id="test_task_id")
- self.assertFalse(sensor.poke(None))
+ assert not sensor.poke(None)
mock_kubernetes_hook.assert_called_once_with()
mock_get_namespaced_crd.assert_called_once_with(
group="sparkoperator.k8s.io",
@@ -593,7 +595,7 @@ def test_running_application(self, mock_get_namespaced_crd, mock_kubernetes_hook
)
def test_submitted_application(self, mock_get_namespaced_crd, mock_kubernetes_hook):
sensor = SparkKubernetesSensor(application_name="spark_pi", dag=self.dag, task_id="test_task_id")
- self.assertFalse(sensor.poke(None))
+ assert not sensor.poke(None)
mock_kubernetes_hook.assert_called_once_with()
mock_get_namespaced_crd.assert_called_once_with(
group="sparkoperator.k8s.io",
@@ -609,7 +611,7 @@ def test_submitted_application(self, mock_get_namespaced_crd, mock_kubernetes_ho
)
def test_pending_rerun_application(self, mock_get_namespaced_crd, mock_kubernetes_hook):
sensor = SparkKubernetesSensor(application_name="spark_pi", dag=self.dag, task_id="test_task_id")
- self.assertFalse(sensor.poke(None))
+ assert not sensor.poke(None)
mock_kubernetes_hook.assert_called_once_with()
mock_get_namespaced_crd.assert_called_once_with(
group="sparkoperator.k8s.io",
@@ -625,7 +627,8 @@ def test_pending_rerun_application(self, mock_get_namespaced_crd, mock_kubernete
)
def test_unknown_application(self, mock_get_namespaced_crd, mock_kubernetes_hook):
sensor = SparkKubernetesSensor(application_name="spark_pi", dag=self.dag, task_id="test_task_id")
- self.assertRaises(AirflowException, sensor.poke, None)
+ with pytest.raises(AirflowException):
+ sensor.poke(None)
mock_kubernetes_hook.assert_called_once_with()
mock_get_namespaced_crd.assert_called_once_with(
group="sparkoperator.k8s.io",
@@ -696,7 +699,8 @@ def test_driver_logging_failure(
dag=self.dag,
task_id="test_task_id",
)
- self.assertRaises(AirflowException, sensor.poke, None)
+ with pytest.raises(AirflowException):
+ sensor.poke(None)
mock_log_call.assert_called_once_with("spark-pi-driver", namespace="default")
error_log_call.assert_called_once_with(TEST_POD_LOG_RESULT)
@@ -722,7 +726,7 @@ def test_driver_logging_completed(
mock_log_call.assert_called_once_with("spark-pi-2020-02-24-1-driver", namespace="default")
log_info_call = info_log_call.mock_calls[1]
log_value = log_info_call[1][0]
- self.assertEqual(log_value, TEST_POD_LOG_RESULT)
+ assert log_value == TEST_POD_LOG_RESULT
@patch(
"kubernetes.client.api.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object",
diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py
index 2c3f966a19c17..7eedbd2a6361f 100644
--- a/tests/providers/databricks/hooks/test_databricks.py
+++ b/tests/providers/databricks/hooks/test_databricks.py
@@ -22,6 +22,7 @@
import unittest
from unittest import mock
+import pytest
from requests import exceptions as requests_exceptions
from airflow import __version__
@@ -171,14 +172,14 @@ def setUp(self, session=None):
def test_parse_host_with_proper_host(self):
host = self.hook._parse_host(HOST)
- self.assertEqual(host, HOST)
+ assert host == HOST
def test_parse_host_with_scheme(self):
host = self.hook._parse_host(HOST_WITH_SCHEME)
- self.assertEqual(host, HOST)
+ assert host == HOST
def test_init_bad_retry_limit(self):
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
DatabricksHook(retry_limit=0)
def test_do_api_call_retries_with_retryable_error(self):
@@ -193,17 +194,17 @@ def test_do_api_call_retries_with_retryable_error(self):
with mock.patch.object(self.hook.log, 'error') as mock_errors:
setup_mock_requests(mock_requests, exception)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
- self.assertEqual(mock_errors.call_count, self.hook.retry_limit)
+ assert mock_errors.call_count == self.hook.retry_limit
@mock.patch('airflow.providers.databricks.hooks.databricks.requests')
def test_do_api_call_does_not_retry_with_non_retryable_error(self, mock_requests):
setup_mock_requests(mock_requests, requests_exceptions.HTTPError, status_code=400)
with mock.patch.object(self.hook.log, 'error') as mock_errors:
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
mock_errors.assert_not_called()
@@ -224,8 +225,8 @@ def test_do_api_call_succeeds_after_retrying(self):
response = self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
- self.assertEqual(mock_errors.call_count, 2)
- self.assertEqual(response, {'run_id': '1'})
+ assert mock_errors.call_count == 2
+ assert response == {'run_id': '1'}
@mock.patch('airflow.providers.databricks.hooks.databricks.sleep')
def test_do_api_call_waits_between_retries(self, mock_sleep):
@@ -244,10 +245,10 @@ def test_do_api_call_waits_between_retries(self, mock_sleep):
mock_sleep.reset_mock()
setup_mock_requests(mock_requests, exception)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
- self.assertEqual(len(mock_sleep.mock_calls), self.hook.retry_limit - 1)
+ assert len(mock_sleep.mock_calls) == self.hook.retry_limit - 1
calls = [mock.call(retry_delay), mock.call(retry_delay)]
mock_sleep.assert_has_calls(calls)
@@ -257,7 +258,7 @@ def test_do_api_call_patch(self, mock_requests):
data = {'cluster_name': 'new_name'}
patched_cluster_name = self.hook._do_api_call(('PATCH', 'api/2.0/jobs/runs/submit'), data)
- self.assertEqual(patched_cluster_name['cluster_name'], 'new_name')
+ assert patched_cluster_name['cluster_name'] == 'new_name'
mock_requests.patch.assert_called_once_with(
submit_run_endpoint(HOST),
json={'cluster_name': 'new_name'},
@@ -273,7 +274,7 @@ def test_submit_run(self, mock_requests):
data = {'notebook_task': NOTEBOOK_TASK, 'new_cluster': NEW_CLUSTER}
run_id = self.hook.submit_run(data)
- self.assertEqual(run_id, '1')
+ assert run_id == '1'
mock_requests.post.assert_called_once_with(
submit_run_endpoint(HOST),
json={
@@ -292,7 +293,7 @@ def test_spark_python_submit_run(self, mock_requests):
data = {'spark_python_task': SPARK_PYTHON_TASK, 'new_cluster': NEW_CLUSTER}
run_id = self.hook.submit_run(data)
- self.assertEqual(run_id, '1')
+ assert run_id == '1'
mock_requests.post.assert_called_once_with(
submit_run_endpoint(HOST),
json={
@@ -314,7 +315,7 @@ def test_run_now(self, mock_requests):
data = {'notebook_params': NOTEBOOK_PARAMS, 'jar_params': JAR_PARAMS, 'job_id': JOB_ID}
run_id = self.hook.run_now(data)
- self.assertEqual(run_id, '1')
+ assert run_id == '1'
mock_requests.post.assert_called_once_with(
run_now_endpoint(HOST),
@@ -331,7 +332,7 @@ def test_get_run_page_url(self, mock_requests):
run_page_url = self.hook.get_run_page_url(RUN_ID)
- self.assertEqual(run_page_url, RUN_PAGE_URL)
+ assert run_page_url == RUN_PAGE_URL
mock_requests.get.assert_called_once_with(
get_run_endpoint(HOST),
json=None,
@@ -347,7 +348,7 @@ def test_get_job_id(self, mock_requests):
job_id = self.hook.get_job_id(RUN_ID)
- self.assertEqual(job_id, JOB_ID)
+ assert job_id == JOB_ID
mock_requests.get.assert_called_once_with(
get_run_endpoint(HOST),
json=None,
@@ -363,7 +364,7 @@ def test_get_run_state(self, mock_requests):
run_state = self.hook.get_run_state(RUN_ID)
- self.assertEqual(run_state, RunState(LIFE_CYCLE_STATE, RESULT_STATE, STATE_MESSAGE))
+ assert run_state == RunState(LIFE_CYCLE_STATE, RESULT_STATE, STATE_MESSAGE)
mock_requests.get.assert_called_once_with(
get_run_endpoint(HOST),
json=None,
@@ -504,10 +505,10 @@ def test_submit_run(self, mock_requests):
data = {'notebook_task': NOTEBOOK_TASK, 'new_cluster': NEW_CLUSTER}
run_id = self.hook.submit_run(data)
- self.assertEqual(run_id, '1')
+ assert run_id == '1'
args = mock_requests.post.call_args
kwargs = args[1]
- self.assertEqual(kwargs['auth'].token, TOKEN)
+ assert kwargs['auth'].token == TOKEN
class TestDatabricksHookTokenWhenNoHostIsProvidedInExtra(TestDatabricksHookToken):
@@ -526,19 +527,19 @@ def test_is_terminal_true(self):
terminal_states = ['TERMINATED', 'SKIPPED', 'INTERNAL_ERROR']
for state in terminal_states:
run_state = RunState(state, '', '')
- self.assertTrue(run_state.is_terminal)
+ assert run_state.is_terminal
def test_is_terminal_false(self):
non_terminal_states = ['PENDING', 'RUNNING', 'TERMINATING']
for state in non_terminal_states:
run_state = RunState(state, '', '')
- self.assertFalse(run_state.is_terminal)
+ assert not run_state.is_terminal
def test_is_terminal_with_nonexistent_life_cycle_state(self):
run_state = RunState('blah', '', '')
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
run_state.is_terminal
def test_is_successful(self):
run_state = RunState('TERMINATED', 'SUCCESS', '')
- self.assertTrue(run_state.is_successful)
+ assert run_state.is_successful
diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py
index 8f43ad5391b65..e920a7ad96027 100644
--- a/tests/providers/databricks/operators/test_databricks.py
+++ b/tests/providers/databricks/operators/test_databricks.py
@@ -20,6 +20,8 @@
from datetime import datetime
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.models import DAG
from airflow.providers.databricks.hooks.databricks import RunState
@@ -70,7 +72,7 @@ def test_deep_string_coerce(self):
'test_list': ['1', '1.0', 'a', 'b'],
'test_tuple': ['1', '1.0', 'a', 'b'],
}
- self.assertDictEqual(databricks_operator._deep_string_coerce(test_json), expected)
+ assert databricks_operator._deep_string_coerce(test_json) == expected
class TestDatabricksSubmitRunOperator(unittest.TestCase):
@@ -85,7 +87,7 @@ def test_init_with_notebook_task_named_parameters(self):
{'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': TASK_ID}
)
- self.assertDictEqual(expected, op.json)
+ assert expected == op.json
def test_init_with_spark_python_task_named_parameters(self):
"""
@@ -98,7 +100,7 @@ def test_init_with_spark_python_task_named_parameters(self):
{'new_cluster': NEW_CLUSTER, 'spark_python_task': SPARK_PYTHON_TASK, 'run_name': TASK_ID}
)
- self.assertDictEqual(expected, op.json)
+ assert expected == op.json
def test_init_with_spark_submit_task_named_parameters(self):
"""
@@ -111,7 +113,7 @@ def test_init_with_spark_submit_task_named_parameters(self):
{'new_cluster': NEW_CLUSTER, 'spark_submit_task': SPARK_SUBMIT_TASK, 'run_name': TASK_ID}
)
- self.assertDictEqual(expected, op.json)
+ assert expected == op.json
def test_init_with_json(self):
"""
@@ -122,7 +124,7 @@ def test_init_with_json(self):
expected = databricks_operator._deep_string_coerce(
{'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': TASK_ID}
)
- self.assertDictEqual(expected, op.json)
+ assert expected == op.json
def test_init_with_specified_run_name(self):
"""
@@ -133,7 +135,7 @@ def test_init_with_specified_run_name(self):
expected = databricks_operator._deep_string_coerce(
{'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK, 'run_name': RUN_NAME}
)
- self.assertDictEqual(expected, op.json)
+ assert expected == op.json
def test_init_with_merging(self):
"""
@@ -154,7 +156,7 @@ def test_init_with_merging(self):
'run_name': TASK_ID,
}
)
- self.assertDictEqual(expected, op.json)
+ assert expected == op.json
def test_init_with_templating(self):
json = {
@@ -171,7 +173,7 @@ def test_init_with_templating(self):
'run_name': TASK_ID,
}
)
- self.assertDictEqual(expected, op.json)
+ assert expected == op.json
def test_init_with_bad_type(self):
json = {'test': datetime.now()}
@@ -180,7 +182,7 @@ def test_init_with_bad_type(self):
r'Type \<(type|class) \'datetime.datetime\'\> used '
+ r'for parameter json\[test\] is not a number or a string'
)
- with self.assertRaisesRegex(AirflowException, exception_message):
+ with pytest.raises(AirflowException, match=exception_message):
DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
@mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
@@ -209,7 +211,7 @@ def test_exec_success(self, db_mock_class):
db_mock.submit_run.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run_state.assert_called_once_with(RUN_ID)
- self.assertEqual(RUN_ID, op.run_id)
+ assert RUN_ID == op.run_id
@mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
def test_exec_failure(self, db_mock_class):
@@ -225,7 +227,7 @@ def test_exec_failure(self, db_mock_class):
db_mock.submit_run.return_value = 1
db_mock.get_run_state.return_value = RunState('TERMINATED', 'FAILED', '')
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
op.execute(None)
expected = databricks_operator._deep_string_coerce(
@@ -241,7 +243,7 @@ def test_exec_failure(self, db_mock_class):
db_mock.submit_run.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run_state.assert_called_once_with(RUN_ID)
- self.assertEqual(RUN_ID, op.run_id)
+ assert RUN_ID == op.run_id
@mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
def test_on_kill(self, db_mock_class):
@@ -266,7 +268,7 @@ def test_init_with_named_parameters(self):
op = DatabricksRunNowOperator(job_id=JOB_ID, task_id=TASK_ID)
expected = databricks_operator._deep_string_coerce({'job_id': 42})
- self.assertDictEqual(expected, op.json)
+ assert expected == op.json
def test_init_with_json(self):
"""
@@ -291,7 +293,7 @@ def test_init_with_json(self):
}
)
- self.assertDictEqual(expected, op.json)
+ assert expected == op.json
def test_init_with_merging(self):
"""
@@ -321,7 +323,7 @@ def test_init_with_merging(self):
}
)
- self.assertDictEqual(expected, op.json)
+ assert expected == op.json
def test_init_with_templating(self):
json = {'notebook_params': NOTEBOOK_PARAMS, 'jar_params': TEMPLATED_JAR_PARAMS}
@@ -336,7 +338,7 @@ def test_init_with_templating(self):
'job_id': JOB_ID,
}
)
- self.assertDictEqual(expected, op.json)
+ assert expected == op.json
def test_init_with_bad_type(self):
json = {'test': datetime.now()}
@@ -345,7 +347,7 @@ def test_init_with_bad_type(self):
r'Type \<(type|class) \'datetime.datetime\'\> used '
+ r'for parameter json\[test\] is not a number or a string'
)
- with self.assertRaisesRegex(AirflowException, exception_message):
+ with pytest.raises(AirflowException, match=exception_message):
DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=json)
@mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
@@ -376,7 +378,7 @@ def test_exec_success(self, db_mock_class):
db_mock.run_now.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run_state.assert_called_once_with(RUN_ID)
- self.assertEqual(RUN_ID, op.run_id)
+ assert RUN_ID == op.run_id
@mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
def test_exec_failure(self, db_mock_class):
@@ -389,7 +391,7 @@ def test_exec_failure(self, db_mock_class):
db_mock.run_now.return_value = 1
db_mock.get_run_state.return_value = RunState('TERMINATED', 'FAILED', '')
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
op.execute(None)
expected = databricks_operator._deep_string_coerce(
@@ -406,7 +408,7 @@ def test_exec_failure(self, db_mock_class):
db_mock.run_now.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run_state.assert_called_once_with(RUN_ID)
- self.assertEqual(RUN_ID, op.run_id)
+ assert RUN_ID == op.run_id
@mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
def test_on_kill(self, db_mock_class):
diff --git a/tests/providers/datadog/hooks/test_datadog.py b/tests/providers/datadog/hooks/test_datadog.py
index 56b5ce7777e06..3033e630ea9fb 100644
--- a/tests/providers/datadog/hooks/test_datadog.py
+++ b/tests/providers/datadog/hooks/test_datadog.py
@@ -20,6 +20,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.datadog.hooks.datadog import DatadogHook
@@ -60,9 +62,9 @@ def setUp(self, mock_get_connection, mock_initialize):
@mock.patch('airflow.providers.datadog.hooks.datadog.DatadogHook.get_connection')
def test_api_key_required(self, mock_get_connection, mock_initialize):
mock_get_connection.return_value = Connection()
- with self.assertRaises(AirflowException) as ctx:
+ with pytest.raises(AirflowException) as ctx:
DatadogHook()
- self.assertEqual(str(ctx.exception), 'api_key must be specified in the Datadog connection details')
+ assert str(ctx.value) == 'api_key must be specified in the Datadog connection details'
def test_validate_response_valid(self):
try:
@@ -71,7 +73,7 @@ def test_validate_response_valid(self):
self.fail('Unexpected AirflowException raised')
def test_validate_response_invalid(self):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.validate_response({'status': 'error'})
@mock.patch('airflow.providers.datadog.hooks.datadog.api.Metric.send')
diff --git a/tests/providers/datadog/sensors/test_datadog.py b/tests/providers/datadog/sensors/test_datadog.py
index 10dae5f31a508..51cb5a8641b0a 100644
--- a/tests/providers/datadog/sensors/test_datadog.py
+++ b/tests/providers/datadog/sensors/test_datadog.py
@@ -92,7 +92,7 @@ def test_sensor_ok(self, api1, api2):
response_check=None,
)
- self.assertTrue(sensor.poke({}))
+ assert sensor.poke({})
@patch('airflow.providers.datadog.hooks.datadog.api.Event.query')
@patch('airflow.providers.datadog.sensors.datadog.api.Event.query')
@@ -111,4 +111,4 @@ def test_sensor_fail(self, api1, api2):
response_check=None,
)
- self.assertFalse(sensor.poke({}))
+ assert not sensor.poke({})
diff --git a/tests/providers/dingding/hooks/test_dingding.py b/tests/providers/dingding/hooks/test_dingding.py
index 6b14769c85e3d..bdb5d3612b618 100644
--- a/tests/providers/dingding/hooks/test_dingding.py
+++ b/tests/providers/dingding/hooks/test_dingding.py
@@ -19,6 +19,8 @@
import json
import unittest
+import pytest
+
from airflow.models import Connection
from airflow.providers.dingding.hooks.dingding import DingdingHook
from airflow.utils import db
@@ -40,7 +42,7 @@ def setUp(self):
def test_get_endpoint_conn_id(self):
hook = DingdingHook(dingding_conn_id=self.conn_id)
endpoint = hook._get_endpoint()
- self.assertEqual('robot/send?access_token=you_token_here', endpoint)
+ assert 'robot/send?access_token=you_token_here' == endpoint
def test_build_text_message_not_remind(self):
config = {
@@ -57,7 +59,7 @@ def test_build_text_message_not_remind(self):
}
hook = DingdingHook(**config)
message = hook._build_message()
- self.assertEqual(json.dumps(expect), message)
+ assert json.dumps(expect) == message
def test_build_text_message_remind_specific(self):
config = {
@@ -74,7 +76,7 @@ def test_build_text_message_remind_specific(self):
}
hook = DingdingHook(**config)
message = hook._build_message()
- self.assertEqual(json.dumps(expect), message)
+ assert json.dumps(expect) == message
def test_build_text_message_remind_all(self):
config = {
@@ -90,7 +92,7 @@ def test_build_text_message_remind_all(self):
}
hook = DingdingHook(**config)
message = hook._build_message()
- self.assertEqual(json.dumps(expect), message)
+ assert json.dumps(expect) == message
def test_build_markdown_message_remind_specific(self):
msg = {
@@ -112,7 +114,7 @@ def test_build_markdown_message_remind_specific(self):
}
hook = DingdingHook(**config)
message = hook._build_message()
- self.assertEqual(json.dumps(expect), message)
+ assert json.dumps(expect) == message
def test_build_markdown_message_remind_all(self):
msg = {
@@ -129,7 +131,7 @@ def test_build_markdown_message_remind_all(self):
expect = {'msgtype': 'markdown', 'markdown': msg, 'at': {'atMobiles': None, 'isAtAll': True}}
hook = DingdingHook(**config)
message = hook._build_message()
- self.assertEqual(json.dumps(expect), message)
+ assert json.dumps(expect) == message
def test_build_link_message(self):
msg = {
@@ -142,7 +144,7 @@ def test_build_link_message(self):
expect = {'msgtype': 'link', 'link': msg}
hook = DingdingHook(**config)
message = hook._build_message()
- self.assertEqual(json.dumps(expect), message)
+ assert json.dumps(expect) == message
def test_build_single_action_card_message(self):
msg = {
@@ -159,7 +161,7 @@ def test_build_single_action_card_message(self):
expect = {'msgtype': 'actionCard', 'actionCard': msg}
hook = DingdingHook(**config)
message = hook._build_message()
- self.assertEqual(json.dumps(expect), message)
+ assert json.dumps(expect) == message
def test_build_multi_action_card_message(self):
msg = {
@@ -178,7 +180,7 @@ def test_build_multi_action_card_message(self):
expect = {'msgtype': 'actionCard', 'actionCard': msg}
hook = DingdingHook(**config)
message = hook._build_message()
- self.assertEqual(json.dumps(expect), message)
+ assert json.dumps(expect) == message
def test_build_feed_card_message(self):
msg = {
@@ -204,7 +206,7 @@ def test_build_feed_card_message(self):
expect = {'msgtype': 'feedCard', 'feedCard': msg}
hook = DingdingHook(**config)
message = hook._build_message()
- self.assertEqual(json.dumps(expect), message)
+ assert json.dumps(expect) == message
def test_send_not_support_type(self):
config = {
@@ -213,4 +215,5 @@ def test_send_not_support_type(self):
'message': 'Airflow dingding text message remind no one',
}
hook = DingdingHook(**config)
- self.assertRaises(ValueError, hook.send)
+ with pytest.raises(ValueError):
+ hook.send()
diff --git a/tests/providers/dingding/operators/test_dingding.py b/tests/providers/dingding/operators/test_dingding.py
index f7ed13a775d43..3635006c58296 100644
--- a/tests/providers/dingding/operators/test_dingding.py
+++ b/tests/providers/dingding/operators/test_dingding.py
@@ -43,12 +43,12 @@ def setUp(self):
def test_execute(self, mock_hook):
operator = DingdingOperator(task_id='dingding_task', dag=self.dag, **self._config)
- self.assertIsNotNone(operator)
- self.assertEqual(self._config['dingding_conn_id'], operator.dingding_conn_id)
- self.assertEqual(self._config['message_type'], operator.message_type)
- self.assertEqual(self._config['message'], operator.message)
- self.assertEqual(self._config['at_mobiles'], operator.at_mobiles)
- self.assertEqual(self._config['at_all'], operator.at_all)
+ assert operator is not None
+ assert self._config['dingding_conn_id'] == operator.dingding_conn_id
+ assert self._config['message_type'] == operator.message_type
+ assert self._config['message'] == operator.message
+ assert self._config['at_mobiles'] == operator.at_mobiles
+ assert self._config['at_all'] == operator.at_all
operator.execute(None)
mock_hook.assert_called_once_with(
diff --git a/tests/providers/discord/hooks/test_discord_webhook.py b/tests/providers/discord/hooks/test_discord_webhook.py
index f28a10a10473f..6d41c576a18f8 100644
--- a/tests/providers/discord/hooks/test_discord_webhook.py
+++ b/tests/providers/discord/hooks/test_discord_webhook.py
@@ -19,6 +19,8 @@
import json
import unittest
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.discord.hooks.discord_webhook import DiscordWebhookHook
@@ -65,7 +67,7 @@ def test_get_webhook_endpoint_manual_token(self):
webhook_endpoint = hook._get_webhook_endpoint(None, provided_endpoint)
# Then
- self.assertEqual(webhook_endpoint, provided_endpoint)
+ assert webhook_endpoint == provided_endpoint
def test_get_webhook_endpoint_invalid_url(self):
# Given
@@ -73,7 +75,7 @@ def test_get_webhook_endpoint_invalid_url(self):
# When/Then
expected_message = 'Expected Discord webhook endpoint in the form of'
- with self.assertRaisesRegex(AirflowException, expected_message):
+ with pytest.raises(AirflowException, match=expected_message):
DiscordWebhookHook(webhook_endpoint=provided_endpoint)
def test_get_webhook_endpoint_conn_id(self):
@@ -86,7 +88,7 @@ def test_get_webhook_endpoint_conn_id(self):
webhook_endpoint = hook._get_webhook_endpoint(conn_id, None)
# Then
- self.assertEqual(webhook_endpoint, expected_webhook_endpoint)
+ assert webhook_endpoint == expected_webhook_endpoint
def test_build_discord_payload(self):
# Given
@@ -96,7 +98,7 @@ def test_build_discord_payload(self):
payload = hook._build_discord_payload()
# Then
- self.assertEqual(self.expected_payload, payload)
+ assert self.expected_payload == payload
def test_build_discord_payload_message_length(self):
# Given
@@ -107,5 +109,5 @@ def test_build_discord_payload_message_length(self):
# When/Then
expected_message = 'Discord message length must be 2000 or fewer characters'
- with self.assertRaisesRegex(AirflowException, expected_message):
+ with pytest.raises(AirflowException, match=expected_message):
hook._build_discord_payload()
diff --git a/tests/providers/discord/operators/test_discord_webhook.py b/tests/providers/discord/operators/test_discord_webhook.py
index 8cf3b64bdcec1..bacefd4276a84 100644
--- a/tests/providers/discord/operators/test_discord_webhook.py
+++ b/tests/providers/discord/operators/test_discord_webhook.py
@@ -43,10 +43,10 @@ def setUp(self):
def test_execute(self):
operator = DiscordWebhookOperator(task_id='discord_webhook_task', dag=self.dag, **self._config)
- self.assertEqual(self._config['http_conn_id'], operator.http_conn_id)
- self.assertEqual(self._config['webhook_endpoint'], operator.webhook_endpoint)
- self.assertEqual(self._config['message'], operator.message)
- self.assertEqual(self._config['username'], operator.username)
- self.assertEqual(self._config['avatar_url'], operator.avatar_url)
- self.assertEqual(self._config['tts'], operator.tts)
- self.assertEqual(self._config['proxy'], operator.proxy)
+ assert self._config['http_conn_id'] == operator.http_conn_id
+ assert self._config['webhook_endpoint'] == operator.webhook_endpoint
+ assert self._config['message'] == operator.message
+ assert self._config['username'] == operator.username
+ assert self._config['avatar_url'] == operator.avatar_url
+ assert self._config['tts'] == operator.tts
+ assert self._config['proxy'] == operator.proxy
diff --git a/tests/providers/docker/hooks/test_docker.py b/tests/providers/docker/hooks/test_docker.py
index 49e810b3d2859..da4995c24d6da 100644
--- a/tests/providers/docker/hooks/test_docker.py
+++ b/tests/providers/docker/hooks/test_docker.py
@@ -19,6 +19,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.utils import db
@@ -54,11 +56,11 @@ def setUp(self):
)
def test_init_fails_when_no_base_url_given(self, _):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
DockerHook(docker_conn_id='docker_default', version='auto', tls=None)
def test_init_fails_when_no_api_version_given(self, _):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
DockerHook(docker_conn_id='docker_default', base_url='unix://var/run/docker.sock', tls=None)
def test_get_conn_override_defaults(self, docker_client_mock):
@@ -79,7 +81,7 @@ def test_get_conn_with_standard_config(self, _):
docker_conn_id='docker_default', base_url='unix://var/run/docker.sock', version='auto'
)
client = hook.get_conn()
- self.assertIsNotNone(client)
+ assert client is not None
except Exception: # pylint: disable=broad-except
self.fail('Could not get connection from Airflow')
@@ -89,7 +91,7 @@ def test_get_conn_with_extra_config(self, _):
docker_conn_id='docker_with_extras', base_url='unix://var/run/docker.sock', version='auto'
)
client = hook.get_conn()
- self.assertIsNotNone(client)
+ assert client is not None
except Exception: # pylint: disable=broad-except
self.fail('Could not get connection from Airflow')
@@ -129,7 +131,7 @@ def test_conn_with_broken_config_missing_username_fails(self, _):
extra='{"email": "some@example.com"}',
)
)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
DockerHook(
docker_conn_id='docker_without_username',
base_url='unix://var/run/docker.sock',
@@ -142,7 +144,7 @@ def test_conn_with_broken_config_missing_host_fails(self, _):
conn_id='docker_without_host', conn_type='docker', login='some_user', password='some_p4$$w0rd'
)
)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
DockerHook(
docker_conn_id='docker_without_host', base_url='unix://var/run/docker.sock', version='auto'
)
diff --git a/tests/providers/docker/operators/test_docker.py b/tests/providers/docker/operators/test_docker.py
index 631d42be2df54..0a2f8383e825f 100644
--- a/tests/providers/docker/operators/test_docker.py
+++ b/tests/providers/docker/operators/test_docker.py
@@ -19,6 +19,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
try:
@@ -106,19 +108,17 @@ def test_execute(self):
)
self.client_mock.pull.assert_called_once_with('ubuntu:latest', stream=True, decode=True)
self.client_mock.wait.assert_called_once_with('some_id')
- self.assertEqual(
- operator.cli.pull('ubuntu:latest', stream=True, decode=True), self.client_mock.pull.return_value
+ assert (
+ operator.cli.pull('ubuntu:latest', stream=True, decode=True) == self.client_mock.pull.return_value
)
def test_private_environment_is_private(self):
operator = DockerOperator(
private_environment={'PRIVATE': 'MESSAGE'}, image='ubuntu:latest', task_id='unittest'
)
- self.assertEqual(
- operator._private_environment,
- {'PRIVATE': 'MESSAGE'},
- "To keep this private, it must be an underscored attribute.",
- )
+ assert operator._private_environment == {
+ 'PRIVATE': 'MESSAGE'
+ }, "To keep this private, it must be an underscored attribute."
@mock.patch('airflow.providers.docker.operators.docker.tls.TLSConfig')
def test_execute_tls(self, tls_class_mock):
@@ -164,7 +164,7 @@ def test_execute_unicode_logs(self):
def test_execute_container_fails(self):
self.client_mock.wait.return_value = {"StatusCode": 1}
operator = DockerOperator(image='ubuntu', owner='unittest', task_id='unittest')
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
operator.execute(None)
@staticmethod
@@ -191,7 +191,7 @@ def test_execute_no_docker_conn_id_no_hook(self):
)
operator.execute(None)
- self.assertEqual(operator.get_hook.call_count, 0, 'Hook called though no docker_conn_id configured')
+ assert operator.get_hook.call_count == 0, 'Hook called though no docker_conn_id configured'
@mock.patch('airflow.providers.docker.operators.docker.DockerHook')
def test_execute_with_docker_conn_id_use_hook(self, hook_class_mock):
@@ -210,13 +210,9 @@ def test_execute_with_docker_conn_id_use_hook(self, hook_class_mock):
operator.execute(None)
- self.assertEqual(
- self.client_class_mock.call_count, 0, 'Client was called on the operator instead of the hook'
- )
- self.assertEqual(
- hook_class_mock.call_count, 1, 'Hook was not called although docker_conn_id configured'
- )
- self.assertEqual(self.client_mock.pull.call_count, 1, 'Image was not pulled using operator client')
+ assert self.client_class_mock.call_count == 0, 'Client was called on the operator instead of the hook'
+ assert hook_class_mock.call_count == 1, 'Hook was not called although docker_conn_id configured'
+ assert self.client_mock.pull.call_count == 1, 'Image was not pulled using operator client'
def test_execute_xcom_behavior(self):
self.client_mock.pull.return_value = [b'{"status":"pull log"}']
@@ -244,23 +240,14 @@ def test_execute_xcom_behavior(self):
xcom_push_result = xcom_push_operator.execute(None)
no_xcom_push_result = no_xcom_push_operator.execute(None)
- self.assertEqual(xcom_push_result, b'container log')
- self.assertIs(no_xcom_push_result, None)
+ assert xcom_push_result == b'container log'
+ assert no_xcom_push_result is None
def test_extra_hosts(self):
hosts_obj = mock.Mock()
operator = DockerOperator(task_id='test', image='test', extra_hosts=hosts_obj)
operator.execute(None)
self.client_mock.create_container.assert_called_once()
- self.assertIn(
- 'host_config',
- self.client_mock.create_container.call_args[1],
- )
- self.assertIn(
- 'extra_hosts',
- self.client_mock.create_host_config.call_args[1],
- )
- self.assertIs(
- hosts_obj,
- self.client_mock.create_host_config.call_args[1]['extra_hosts'],
- )
+ assert 'host_config' in self.client_mock.create_container.call_args[1]
+ assert 'extra_hosts' in self.client_mock.create_host_config.call_args[1]
+ assert hosts_obj is self.client_mock.create_host_config.call_args[1]['extra_hosts']
diff --git a/tests/providers/docker/operators/test_docker_swarm.py b/tests/providers/docker/operators/test_docker_swarm.py
index a3208ef00138b..9470752e5ee36 100644
--- a/tests/providers/docker/operators/test_docker_swarm.py
+++ b/tests/providers/docker/operators/test_docker_swarm.py
@@ -19,6 +19,7 @@
import unittest
from unittest import mock
+import pytest
import requests
from docker import APIClient
@@ -90,13 +91,11 @@ def _client_service_logs_effect():
)
csargs, cskwargs = client_mock.create_service.call_args_list[0]
- self.assertEqual(
- len(csargs), 1, 'create_service called with different number of arguments than expected'
- )
- self.assertEqual(csargs, (mock_obj,))
- self.assertEqual(cskwargs['labels'], {'name': 'airflow__adhoc_airflow__unittest'})
- self.assertTrue(cskwargs['name'].startswith('airflow-'))
- self.assertEqual(client_mock.tasks.call_count, 5)
+ assert len(csargs) == 1, 'create_service called with different number of arguments than expected'
+ assert csargs == (mock_obj,)
+ assert cskwargs['labels'] == {'name': 'airflow__adhoc_airflow__unittest'}
+ assert cskwargs['name'].startswith('airflow-')
+ assert client_mock.tasks.call_count == 5
client_mock.remove_service.assert_called_once_with('some_id')
@mock.patch('airflow.providers.docker.operators.docker.APIClient')
@@ -120,11 +119,9 @@ def test_no_auto_remove(self, types_mock, client_class_mock):
operator = DockerSwarmOperator(image='', auto_remove=False, task_id='unittest', enable_logging=False)
operator.execute(None)
- self.assertEqual(
- client_mock.remove_service.call_count,
- 0,
- 'Docker service being removed even when `auto_remove` set to `False`',
- )
+ assert (
+ client_mock.remove_service.call_count == 0
+ ), 'Docker service being removed even when `auto_remove` set to `False`'
@mock.patch('airflow.providers.docker.operators.docker.APIClient')
@mock.patch('airflow.providers.docker.operators.docker_swarm.types')
@@ -146,9 +143,9 @@ def test_failed_service_raises_error(self, types_mock, client_class_mock):
operator = DockerSwarmOperator(image='', auto_remove=False, task_id='unittest', enable_logging=False)
msg = "Service failed: {'ID': 'some_id'}"
- with self.assertRaises(AirflowException) as error:
+ with pytest.raises(AirflowException) as ctx:
operator.execute(None)
- self.assertEqual(str(error.exception), msg)
+ assert str(ctx.value) == msg
@mock.patch('airflow.providers.docker.operators.docker.APIClient')
@mock.patch('airflow.providers.docker.operators.docker_swarm.types')
diff --git a/tests/providers/elasticsearch/hooks/test_elasticsearch.py b/tests/providers/elasticsearch/hooks/test_elasticsearch.py
index 854f149d81a47..b96f64eadb9b6 100644
--- a/tests/providers/elasticsearch/hooks/test_elasticsearch.py
+++ b/tests/providers/elasticsearch/hooks/test_elasticsearch.py
@@ -66,7 +66,7 @@ def test_get_first_record(self):
result_sets = [('row1',), ('row2',)]
self.cur.fetchone.return_value = result_sets[0]
- self.assertEqual(result_sets[0], self.db_hook.get_first(statement))
+ assert result_sets[0] == self.db_hook.get_first(statement)
self.conn.close.assert_called_once_with()
self.cur.close.assert_called_once_with()
self.cur.execute.assert_called_once_with(statement)
@@ -76,7 +76,7 @@ def test_get_records(self):
result_sets = [('row1',), ('row2',)]
self.cur.fetchall.return_value = result_sets
- self.assertEqual(result_sets, self.db_hook.get_records(statement))
+ assert result_sets == self.db_hook.get_records(statement)
self.conn.close.assert_called_once_with()
self.cur.close.assert_called_once_with()
self.cur.execute.assert_called_once_with(statement)
@@ -89,9 +89,9 @@ def test_get_pandas_df(self):
self.cur.fetchall.return_value = result_sets
df = self.db_hook.get_pandas_df(statement)
- self.assertEqual(column, df.columns[0])
+ assert column == df.columns[0]
- self.assertEqual(result_sets[0][0], df.values.tolist()[0][0])
- self.assertEqual(result_sets[1][0], df.values.tolist()[1][0])
+ assert result_sets[0][0] == df.values.tolist()[0][0]
+ assert result_sets[1][0] == df.values.tolist()[1][0]
self.cur.execute.assert_called_once_with(statement)
diff --git a/tests/providers/elasticsearch/log/test_es_task_handler.py b/tests/providers/elasticsearch/log/test_es_task_handler.py
index 3531c09611b70..a00b9f05745e8 100644
--- a/tests/providers/elasticsearch/log/test_es_task_handler.py
+++ b/tests/providers/elasticsearch/log/test_es_task_handler.py
@@ -85,7 +85,7 @@ def tearDown(self):
shutil.rmtree(self.local_log_location.split(os.path.sep)[0], ignore_errors=True)
def test_client(self):
- self.assertIsInstance(self.es_task_handler.client, elasticsearch.Elasticsearch)
+ assert isinstance(self.es_task_handler.client, elasticsearch.Elasticsearch)
def test_client_with_config(self):
es_conf = dict(conf.getsection("elasticsearch_configs"))
@@ -93,7 +93,7 @@ def test_client_with_config(self):
"use_ssl": False,
"verify_certs": True,
}
- self.assertDictEqual(es_conf, expected_dict)
+ assert es_conf == expected_dict
# ensure creating with configs does not fail
ElasticsearchTaskHandler(
self.local_log_location,
@@ -112,13 +112,13 @@ def test_read(self):
self.ti, 1, {'offset': 0, 'last_log_timestamp': str(ts), 'end_of_log': False}
)
- self.assertEqual(1, len(logs))
- self.assertEqual(len(logs), len(metadatas))
- self.assertEqual(len(logs[0]), 1)
- self.assertEqual(self.test_message, logs[0][0][-1])
- self.assertFalse(metadatas[0]['end_of_log'])
- self.assertEqual('1', metadatas[0]['offset'])
- self.assertTrue(timezone.parse(metadatas[0]['last_log_timestamp']) > ts)
+ assert 1 == len(logs)
+ assert len(logs) == len(metadatas)
+ assert len(logs[0]) == 1
+ assert self.test_message == logs[0][0][-1]
+ assert not metadatas[0]['end_of_log']
+ assert '1' == metadatas[0]['offset']
+ assert timezone.parse(metadatas[0]['last_log_timestamp']) > ts
def test_read_with_match_phrase_query(self):
similar_log_id = '{task_id}-{dag_id}-2016-01-01T00:00:00+00:00-1'.format(
@@ -133,23 +133,23 @@ def test_read_with_match_phrase_query(self):
logs, metadatas = self.es_task_handler.read(
self.ti, 1, {'offset': 0, 'last_log_timestamp': str(ts), 'end_of_log': False}
)
- self.assertEqual(1, len(logs))
- self.assertEqual(len(logs), len(metadatas))
- self.assertEqual(self.test_message, logs[0][0][-1])
- self.assertNotEqual(another_test_message, logs[0])
+ assert 1 == len(logs)
+ assert len(logs) == len(metadatas)
+ assert self.test_message == logs[0][0][-1]
+ assert another_test_message != logs[0]
- self.assertFalse(metadatas[0]['end_of_log'])
- self.assertEqual('1', metadatas[0]['offset'])
- self.assertTrue(timezone.parse(metadatas[0]['last_log_timestamp']) > ts)
+ assert not metadatas[0]['end_of_log']
+ assert '1' == metadatas[0]['offset']
+ assert timezone.parse(metadatas[0]['last_log_timestamp']) > ts
def test_read_with_none_metadata(self):
logs, metadatas = self.es_task_handler.read(self.ti, 1)
- self.assertEqual(1, len(logs))
- self.assertEqual(len(logs), len(metadatas))
- self.assertEqual(self.test_message, logs[0][0][-1])
- self.assertFalse(metadatas[0]['end_of_log'])
- self.assertEqual('1', metadatas[0]['offset'])
- self.assertTrue(timezone.parse(metadatas[0]['last_log_timestamp']) < pendulum.now())
+ assert 1 == len(logs)
+ assert len(logs) == len(metadatas)
+ assert self.test_message == logs[0][0][-1]
+ assert not metadatas[0]['end_of_log']
+ assert '1' == metadatas[0]['offset']
+ assert timezone.parse(metadatas[0]['last_log_timestamp']) < pendulum.now()
def test_read_nonexistent_log(self):
ts = pendulum.now()
@@ -160,39 +160,39 @@ def test_read_nonexistent_log(self):
logs, metadatas = self.es_task_handler.read(
self.ti, 1, {'offset': 0, 'last_log_timestamp': str(ts), 'end_of_log': False}
)
- self.assertEqual(1, len(logs))
- self.assertEqual(len(logs), len(metadatas))
- self.assertEqual([[]], logs)
- self.assertFalse(metadatas[0]['end_of_log'])
- self.assertEqual('0', metadatas[0]['offset'])
+ assert 1 == len(logs)
+ assert len(logs) == len(metadatas)
+ assert [[]] == logs
+ assert not metadatas[0]['end_of_log']
+ assert '0' == metadatas[0]['offset']
# last_log_timestamp won't change if no log lines read.
- self.assertTrue(timezone.parse(metadatas[0]['last_log_timestamp']) == ts)
+ assert timezone.parse(metadatas[0]['last_log_timestamp']) == ts
def test_read_with_empty_metadata(self):
ts = pendulum.now()
logs, metadatas = self.es_task_handler.read(self.ti, 1, {})
- self.assertEqual(1, len(logs))
- self.assertEqual(len(logs), len(metadatas))
- self.assertEqual(self.test_message, logs[0][0][-1])
- self.assertFalse(metadatas[0]['end_of_log'])
+ assert 1 == len(logs)
+ assert len(logs) == len(metadatas)
+ assert self.test_message == logs[0][0][-1]
+ assert not metadatas[0]['end_of_log']
# offset should be initialized to 0 if not provided.
- self.assertEqual('1', metadatas[0]['offset'])
+ assert '1' == metadatas[0]['offset']
# last_log_timestamp will be initialized using log reading time
# if not last_log_timestamp is provided.
- self.assertTrue(timezone.parse(metadatas[0]['last_log_timestamp']) > ts)
+ assert timezone.parse(metadatas[0]['last_log_timestamp']) > ts
# case where offset is missing but metadata not empty.
self.es.delete(index=self.index_name, doc_type=self.doc_type, id=1)
logs, metadatas = self.es_task_handler.read(self.ti, 1, {'end_of_log': False})
- self.assertEqual(1, len(logs))
- self.assertEqual(len(logs), len(metadatas))
- self.assertEqual([[]], logs)
- self.assertFalse(metadatas[0]['end_of_log'])
+ assert 1 == len(logs)
+ assert len(logs) == len(metadatas)
+ assert [[]] == logs
+ assert not metadatas[0]['end_of_log']
# offset should be initialized to 0 if not provided.
- self.assertEqual('0', metadatas[0]['offset'])
+ assert '0' == metadatas[0]['offset']
# last_log_timestamp will be initialized using log reading time
# if not last_log_timestamp is provided.
- self.assertTrue(timezone.parse(metadatas[0]['last_log_timestamp']) > ts)
+ assert timezone.parse(metadatas[0]['last_log_timestamp']) > ts
def test_read_timeout(self):
ts = pendulum.now().subtract(minutes=5)
@@ -201,13 +201,13 @@ def test_read_timeout(self):
logs, metadatas = self.es_task_handler.read(
self.ti, 1, {'offset': 0, 'last_log_timestamp': str(ts), 'end_of_log': False}
)
- self.assertEqual(1, len(logs))
- self.assertEqual(len(logs), len(metadatas))
- self.assertEqual([[]], logs)
- self.assertTrue(metadatas[0]['end_of_log'])
+ assert 1 == len(logs)
+ assert len(logs) == len(metadatas)
+ assert [[]] == logs
+ assert metadatas[0]['end_of_log']
# offset should be initialized to 0 if not provided.
- self.assertEqual('0', metadatas[0]['offset'])
- self.assertTrue(timezone.parse(metadatas[0]['last_log_timestamp']) == ts)
+ assert '0' == metadatas[0]['offset']
+ assert timezone.parse(metadatas[0]['last_log_timestamp']) == ts
def test_read_as_download_logs(self):
ts = pendulum.now()
@@ -216,14 +216,14 @@ def test_read_as_download_logs(self):
1,
{'offset': 0, 'last_log_timestamp': str(ts), 'download_logs': True, 'end_of_log': False},
)
- self.assertEqual(1, len(logs))
- self.assertEqual(len(logs), len(metadatas))
- self.assertEqual(len(logs[0]), 1)
- self.assertEqual(self.test_message, logs[0][0][-1])
- self.assertFalse(metadatas[0]['end_of_log'])
- self.assertTrue(metadatas[0]['download_logs'])
- self.assertEqual('1', metadatas[0]['offset'])
- self.assertTrue(timezone.parse(metadatas[0]['last_log_timestamp']) > ts)
+ assert 1 == len(logs)
+ assert len(logs) == len(metadatas)
+ assert len(logs[0]) == 1
+ assert self.test_message == logs[0][0][-1]
+ assert not metadatas[0]['end_of_log']
+ assert metadatas[0]['download_logs']
+ assert '1' == metadatas[0]['offset']
+ assert timezone.parse(metadatas[0]['last_log_timestamp']) > ts
def test_read_raises(self):
with mock.patch.object(self.es_task_handler.log, 'exception') as mock_exception:
@@ -232,17 +232,17 @@ def test_read_raises(self):
logs, metadatas = self.es_task_handler.read(self.ti, 1)
assert mock_exception.call_count == 1
args, kwargs = mock_exception.call_args
- self.assertIn("Could not read log with log_id:", args[0])
+ assert "Could not read log with log_id:" in args[0]
- self.assertEqual(1, len(logs))
- self.assertEqual(len(logs), len(metadatas))
- self.assertEqual([[]], logs)
- self.assertFalse(metadatas[0]['end_of_log'])
- self.assertEqual('0', metadatas[0]['offset'])
+ assert 1 == len(logs)
+ assert len(logs) == len(metadatas)
+ assert [[]] == logs
+ assert not metadatas[0]['end_of_log']
+ assert '0' == metadatas[0]['offset']
def test_set_context(self):
self.es_task_handler.set_context(self.ti)
- self.assertTrue(self.es_task_handler.mark_end_on_close)
+ assert self.es_task_handler.mark_end_on_close
def test_set_context_w_json_format_and_write_stdout(self):
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@@ -272,9 +272,7 @@ def test_read_with_json_format(self):
logs, _ = self.es_task_handler.read(
self.ti, 1, {'offset': 0, 'last_log_timestamp': str(ts), 'end_of_log': False}
)
- self.assertEqual(
- "[2020-12-24 19:25:00,962] {taskinstance.py:851} INFO - some random stuff", logs[0][0][1]
- )
+ assert "[2020-12-24 19:25:00,962] {taskinstance.py:851} INFO - some random stuff" == logs[0][0][1]
def test_close(self):
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@@ -289,8 +287,8 @@ def test_close(self):
# have the log uploaded but will not be stored in elasticsearch.
# so apply the strip() to log_file.read()
log_line = log_file.read().strip()
- self.assertEqual(self.end_of_log_mark.strip(), log_line)
- self.assertTrue(self.es_task_handler.closed)
+ assert self.end_of_log_mark.strip() == log_line
+ assert self.es_task_handler.closed
def test_close_no_mark_end(self):
self.ti.raw = True
@@ -299,8 +297,8 @@ def test_close_no_mark_end(self):
with open(
os.path.join(self.local_log_location, self.filename_template.format(try_number=1))
) as log_file:
- self.assertNotIn(self.end_of_log_mark, log_file.read())
- self.assertTrue(self.es_task_handler.closed)
+ assert self.end_of_log_mark not in log_file.read()
+ assert self.es_task_handler.closed
def test_close_closed(self):
self.es_task_handler.closed = True
@@ -309,7 +307,7 @@ def test_close_closed(self):
with open(
os.path.join(self.local_log_location, self.filename_template.format(try_number=1))
) as log_file:
- self.assertEqual(0, len(log_file.read()))
+ assert 0 == len(log_file.read())
def test_close_with_no_handler(self):
self.es_task_handler.set_context(self.ti)
@@ -318,8 +316,8 @@ def test_close_with_no_handler(self):
with open(
os.path.join(self.local_log_location, self.filename_template.format(try_number=1))
) as log_file:
- self.assertEqual(0, len(log_file.read()))
- self.assertTrue(self.es_task_handler.closed)
+ assert 0 == len(log_file.read())
+ assert self.es_task_handler.closed
def test_close_with_no_stream(self):
self.es_task_handler.set_context(self.ti)
@@ -328,8 +326,8 @@ def test_close_with_no_stream(self):
with open(
os.path.join(self.local_log_location, self.filename_template.format(try_number=1))
) as log_file:
- self.assertIn(self.end_of_log_mark, log_file.read())
- self.assertTrue(self.es_task_handler.closed)
+ assert self.end_of_log_mark in log_file.read()
+ assert self.es_task_handler.closed
self.es_task_handler.set_context(self.ti)
self.es_task_handler.handler.stream.close()
@@ -337,8 +335,8 @@ def test_close_with_no_stream(self):
with open(
os.path.join(self.local_log_location, self.filename_template.format(try_number=1))
) as log_file:
- self.assertIn(self.end_of_log_mark, log_file.read())
- self.assertTrue(self.es_task_handler.closed)
+ assert self.end_of_log_mark in log_file.read()
+ assert self.es_task_handler.closed
def test_render_log_id(self):
expected_log_id = (
@@ -346,7 +344,7 @@ def test_render_log_id(self):
'task_for_testing_file_log_handler-2016-01-01T00:00:00+00:00-1'
)
log_id = self.es_task_handler._render_log_id(self.ti, 1)
- self.assertEqual(expected_log_id, log_id)
+ assert expected_log_id == log_id
# Switch to use jinja template.
self.es_task_handler = ElasticsearchTaskHandler(
@@ -359,11 +357,11 @@ def test_render_log_id(self):
self.json_fields,
)
log_id = self.es_task_handler._render_log_id(self.ti, 1)
- self.assertEqual(expected_log_id, log_id)
+ assert expected_log_id == log_id
def test_clean_execution_date(self):
clean_execution_date = self.es_task_handler._clean_execution_date(datetime(2016, 7, 8, 9, 10, 11, 12))
- self.assertEqual('2016_07_08T09_10_11_000012', clean_execution_date)
+ assert '2016_07_08T09_10_11_000012' == clean_execution_date
@parameterized.expand(
[
@@ -385,4 +383,4 @@ def test_get_external_log_url(self, es_frontend, expected_url):
frontend=es_frontend,
)
url = es_task_handler.get_external_log_url(self.ti, self.ti.try_number)
- self.assertEqual(expected_url, url)
+ assert expected_url == url
diff --git a/tests/providers/exasol/hooks/test_exasol.py b/tests/providers/exasol/hooks/test_exasol.py
index 4b964e102ad33..5e729ac6adc3a 100644
--- a/tests/providers/exasol/hooks/test_exasol.py
+++ b/tests/providers/exasol/hooks/test_exasol.py
@@ -21,6 +21,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow import models
from airflow.providers.exasol.hooks.exasol import ExasolHook
@@ -47,11 +49,11 @@ def test_get_conn(self, mock_pyexasol):
mock_connect = mock_pyexasol.connect
mock_connect.assert_called_once()
args, kwargs = mock_connect.call_args
- self.assertEqual(args, ())
- self.assertEqual(kwargs['user'], 'login')
- self.assertEqual(kwargs['password'], 'password')
- self.assertEqual(kwargs['dsn'], 'host:1234')
- self.assertEqual(kwargs['schema'], 'schema')
+ assert args == ()
+ assert kwargs['user'] == 'login'
+ assert kwargs['password'] == 'password'
+ assert kwargs['dsn'] == 'host:1234'
+ assert kwargs['schema'] == 'schema'
@mock.patch('airflow.providers.exasol.hooks.exasol.pyexasol')
def test_get_conn_extra_args(self, mock_pyexasol):
@@ -60,8 +62,8 @@ def test_get_conn_extra_args(self, mock_pyexasol):
mock_connect = mock_pyexasol.connect
mock_connect.assert_called_once()
args, kwargs = mock_connect.call_args
- self.assertEqual(args, ())
- self.assertEqual(kwargs['encryption'], True)
+ assert args == ()
+ assert kwargs['encryption'] is True
class TestExasolHook(unittest.TestCase):
@@ -90,7 +92,7 @@ def test_set_autocommit(self):
def test_get_autocommit(self):
setattr(self.conn, 'autocommit', True)
setattr(self.conn, 'attr', {'autocommit': False})
- self.assertFalse(self.db_hook.get_autocommit(self.conn))
+ assert not self.db_hook.get_autocommit(self.conn)
def test_run_without_autocommit(self):
sql = 'SQL'
@@ -124,17 +126,19 @@ def test_run_multi_queries(self):
self.conn.set_autocommit.assert_called_once_with(True)
for i in range(len(self.conn.execute.call_args_list)):
args, kwargs = self.conn.execute.call_args_list[i]
- self.assertEqual(len(args), 2)
- self.assertEqual(args[0], sql[i])
- self.assertEqual(kwargs, {})
+ assert len(args) == 2
+ assert args[0] == sql[i]
+ assert kwargs == {}
self.conn.execute.assert_called_with(sql[1], None)
self.conn.commit.assert_not_called()
def test_bulk_load(self):
- self.assertRaises(NotImplementedError, self.db_hook.bulk_load, 'table', '/tmp/file')
+ with pytest.raises(NotImplementedError):
+ self.db_hook.bulk_load('table', '/tmp/file')
def test_bulk_dump(self):
- self.assertRaises(NotImplementedError, self.db_hook.bulk_dump, 'table', '/tmp/file')
+ with pytest.raises(NotImplementedError):
+ self.db_hook.bulk_dump('table', '/tmp/file')
def test_serialize_cell(self):
- self.assertEqual('foo', self.db_hook._serialize_cell('foo', None))
+ assert 'foo' == self.db_hook._serialize_cell('foo', None)
diff --git a/tests/providers/ftp/sensors/test_ftp.py b/tests/providers/ftp/sensors/test_ftp.py
index df10d60d13bcd..14abc5b3a39c0 100644
--- a/tests/providers/ftp/sensors/test_ftp.py
+++ b/tests/providers/ftp/sensors/test_ftp.py
@@ -20,6 +20,8 @@
from ftplib import error_perm
from unittest import mock
+import pytest
+
from airflow.providers.ftp.hooks.ftp import FTPHook
from airflow.providers.ftp.sensors.ftp import FTPSensor
@@ -36,10 +38,10 @@ def test_poke(self, mock_hook):
None,
]
- self.assertFalse(op.poke(None))
- self.assertFalse(op.poke(None))
- self.assertFalse(op.poke(None))
- self.assertTrue(op.poke(None))
+ assert not op.poke(None)
+ assert not op.poke(None)
+ assert not op.poke(None)
+ assert op.poke(None)
@mock.patch('airflow.providers.ftp.sensors.ftp.FTPHook', spec=FTPHook)
def test_poke_fails_due_error(self, mock_hook):
@@ -49,10 +51,10 @@ def test_poke_fails_due_error(self, mock_hook):
"530: Login authentication failed"
)
- with self.assertRaises(error_perm) as context:
+ with pytest.raises(error_perm) as ctx:
op.execute(None)
- self.assertTrue("530" in str(context.exception))
+ assert "530" in str(ctx.value)
@mock.patch('airflow.providers.ftp.sensors.ftp.FTPHook', spec=FTPHook)
def test_poke_fail_on_transient_error(self, mock_hook):
@@ -62,10 +64,10 @@ def test_poke_fail_on_transient_error(self, mock_hook):
"434: Host unavailable"
)
- with self.assertRaises(error_perm) as context:
+ with pytest.raises(error_perm) as ctx:
op.execute(None)
- self.assertTrue("434" in str(context.exception))
+ assert "434" in str(ctx.value)
@mock.patch('airflow.providers.ftp.sensors.ftp.FTPHook', spec=FTPHook)
def test_poke_ignore_transient_error(self, mock_hook):
@@ -78,5 +80,5 @@ def test_poke_ignore_transient_error(self, mock_hook):
None,
]
- self.assertFalse(op.poke(None))
- self.assertTrue(op.poke(None))
+ assert not op.poke(None)
+ assert op.poke(None)
diff --git a/tests/providers/google/cloud/_internal_client/test_secret_manager_client.py b/tests/providers/google/cloud/_internal_client/test_secret_manager_client.py
index c842fc75cbf4d..f28b0d2393f0e 100644
--- a/tests/providers/google/cloud/_internal_client/test_secret_manager_client.py
+++ b/tests/providers/google/cloud/_internal_client/test_secret_manager_client.py
@@ -51,7 +51,7 @@ def test_get_non_existing_key(self, mock_client_info, mock_secrets_client):
secrets_client = _SecretManagerClient(credentials="credentials")
secret = secrets_client.get_secret(secret_id="missing", project_id="project_id")
mock_client.secret_version_path.assert_called_once_with("project_id", 'missing', 'latest')
- self.assertIsNone(secret)
+ assert secret is None
mock_client.access_secret_version.assert_called_once_with('full-path')
@mock.patch(INTERNAL_CLIENT_MODULE + ".SecretManagerServiceClient")
@@ -66,7 +66,7 @@ def test_get_no_permissions(self, mock_client_info, mock_secrets_client):
secrets_client = _SecretManagerClient(credentials="credentials")
secret = secrets_client.get_secret(secret_id="missing", project_id="project_id")
mock_client.secret_version_path.assert_called_once_with("project_id", 'missing', 'latest')
- self.assertIsNone(secret)
+ assert secret is None
mock_client.access_secret_version.assert_called_once_with('full-path')
@mock.patch(INTERNAL_CLIENT_MODULE + ".SecretManagerServiceClient")
@@ -82,7 +82,7 @@ def test_get_existing_key(self, mock_client_info, mock_secrets_client):
secrets_client = _SecretManagerClient(credentials="credentials")
secret = secrets_client.get_secret(secret_id="existing", project_id="project_id")
mock_client.secret_version_path.assert_called_once_with("project_id", 'existing', 'latest')
- self.assertEqual("result", secret)
+ assert "result" == secret
mock_client.access_secret_version.assert_called_once_with('full-path')
@mock.patch(INTERNAL_CLIENT_MODULE + ".SecretManagerServiceClient")
@@ -100,5 +100,5 @@ def test_get_existing_key_with_version(self, mock_client_info, mock_secrets_clie
secret_id="existing", project_id="project_id", secret_version="test-version"
)
mock_client.secret_version_path.assert_called_once_with("project_id", 'existing', 'test-version')
- self.assertEqual("result", secret)
+ assert "result" == secret
mock_client.access_secret_version.assert_called_once_with('full-path')
diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py b/tests/providers/google/cloud/hooks/test_bigquery.py
index 840f26a5bd3f1..62999864347b3 100644
--- a/tests/providers/google/cloud/hooks/test_bigquery.py
+++ b/tests/providers/google/cloud/hooks/test_bigquery.py
@@ -19,6 +19,7 @@
import unittest
from unittest import mock
+import pytest
from google.cloud.bigquery import DEFAULT_RETRY, DatasetReference, Table, TableReference
from google.cloud.bigquery.dataset import AccessEntry, Dataset, DatasetListItem
from google.cloud.exceptions import NotFound
@@ -77,7 +78,7 @@ def test_bigquery_client_creation(self, mock_build, mock_authorize, mock_bigquer
location=self.hook.location,
num_retries=self.hook.num_retries,
)
- self.assertEqual(mock_bigquery_connection.return_value, result)
+ assert mock_bigquery_connection.return_value == result
@mock.patch("airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__")
def test_bigquery_bigquery_conn_id_deprecation_warning(
@@ -89,26 +90,26 @@ def test_bigquery_bigquery_conn_id_deprecation_warning(
"The bigquery_conn_id parameter has been deprecated. "
"You should pass the gcp_conn_id parameter."
)
- with self.assertWarns(DeprecationWarning) as warn:
+ with pytest.warns(DeprecationWarning) as warnings:
BigQueryHook(bigquery_conn_id=bigquery_conn_id)
mock_base_hook_init.assert_called_once_with(
delegate_to=None,
gcp_conn_id='bigquery conn id',
impersonation_chain=None,
)
- self.assertEqual(warning_message, str(warn.warning))
+ assert warning_message == str(warnings[0].message)
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
def test_location_propagates_properly(self, run_with_config, _):
# TODO: this creates side effect
- self.assertIsNone(self.hook.location)
+ assert self.hook.location is None
self.hook.run_query(sql='select 1', location='US')
assert run_with_config.call_count == 1
- self.assertEqual(self.hook.location, 'US')
+ assert self.hook.location == 'US'
def test_bigquery_insert_rows_not_implemented(self):
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
self.hook.insert_rows(table="table", rows=[1, 2])
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client")
@@ -166,10 +167,13 @@ def test_get_pandas_df(self, mock_read_gbq):
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
def test_invalid_schema_update_options(self, mock_get_service):
- with self.assertRaisesRegex(
+ with pytest.raises(
Exception,
- r"\['THIS IS NOT VALID'\] contains invalid schema update options.Please only use one or more of "
- r"the following options: \['ALLOW_FIELD_ADDITION', 'ALLOW_FIELD_RELAXATION'\]",
+ match=(
+ r"\['THIS IS NOT VALID'\] contains invalid schema update options."
+ r"Please only use one or more of the following options: "
+ r"\['ALLOW_FIELD_ADDITION', 'ALLOW_FIELD_RELAXATION'\]"
+ ),
):
self.hook.run_load(
@@ -181,9 +185,9 @@ def test_invalid_schema_update_options(self, mock_get_service):
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
def test_invalid_schema_update_and_write_disposition(self, mock_get_service):
- with self.assertRaisesRegex(
+ with pytest.raises(
Exception,
- "schema_update_options is only allowed if"
+ match="schema_update_options is only allowed if"
" write_disposition is 'WRITE_APPEND' or 'WRITE_TRUNCATE'.",
):
@@ -219,14 +223,14 @@ def test_run_query_sql_dialect_default(
):
self.hook.run_query('query')
_, kwargs = mock_insert.call_args
- self.assertIs(kwargs['configuration']['query']['useLegacySql'], True)
+ assert kwargs['configuration']['query']['useLegacySql'] is True
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
def test_run_query_sql_dialect(self, mock_insert, _):
self.hook.run_query('query', use_legacy_sql=False)
_, kwargs = mock_insert.call_args
- self.assertIs(kwargs['configuration']['query']['useLegacySql'], False)
+ assert kwargs['configuration']['query']['useLegacySql'] is False
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
@@ -240,7 +244,7 @@ def test_run_query_sql_dialect_legacy_with_query_params(self, mock_insert, _):
]
self.hook.run_query('query', use_legacy_sql=False, query_params=params)
_, kwargs = mock_insert.call_args
- self.assertIs(kwargs['configuration']['query']['useLegacySql'], False)
+ assert kwargs['configuration']['query']['useLegacySql'] is False
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
def test_run_query_sql_dialect_legacy_with_query_params_fails(self, _):
@@ -251,13 +255,13 @@ def test_run_query_sql_dialect_legacy_with_query_params_fails(self, _):
'parameterValue': {'value': "param_value"},
}
]
- with self.assertRaisesRegex(ValueError, "Query parameters are not allowed when using legacy SQL"):
+ with pytest.raises(ValueError, match="Query parameters are not allowed when using legacy SQL"):
self.hook.run_query('query', use_legacy_sql=True, query_params=params)
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
def test_run_query_without_sql_fails(self, _):
- with self.assertRaisesRegex(
- TypeError, r"`BigQueryBaseCursor.run_query` missing 1 required positional argument: `sql`"
+ with pytest.raises(
+ TypeError, match=r"`BigQueryBaseCursor.run_query` missing 1 required positional argument: `sql`"
):
self.hook.run_query(sql=None)
@@ -287,8 +291,8 @@ def test_run_query_schema_update_options(
write_disposition=write_disposition,
)
_, kwargs = mock_insert.call_args
- self.assertEqual(kwargs['configuration']['query']['schemaUpdateOptions'], schema_update_options)
- self.assertEqual(kwargs['configuration']['query']['writeDisposition'], write_disposition)
+ assert kwargs['configuration']['query']['schemaUpdateOptions'] == schema_update_options
+ assert kwargs['configuration']['query']['writeDisposition'] == write_disposition
@parameterized.expand(
[
@@ -322,7 +326,7 @@ def test_run_query_schema_update_options_incorrect(
expected_regex,
mock_get_service,
):
- with self.assertRaisesRegex(ValueError, expected_regex):
+ with pytest.raises(ValueError, match=expected_regex):
self.hook.run_query(
sql='query',
destination_dataset_table='my_dataset.my_table',
@@ -341,46 +345,48 @@ def test_api_resource_configs(
):
self.hook.run_query('query', api_resource_configs={'query': {'useQueryCache': bool_val}})
_, kwargs = mock_insert.call_args
- self.assertIs(kwargs["configuration"]['query']['useQueryCache'], bool_val)
- self.assertIs(kwargs["configuration"]['query']['useLegacySql'], True)
+ assert kwargs["configuration"]['query']['useQueryCache'] is bool_val
+ assert kwargs["configuration"]['query']['useLegacySql'] is True
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
def test_api_resource_configs_duplication_warning(self, mock_get_service):
- with self.assertRaisesRegex(
+ with pytest.raises(
ValueError,
- r"Values of useLegacySql param are duplicated\. api_resource_configs contained useLegacySql "
- r"param in `query` config and useLegacySql was also provided with arg to run_query\(\) method\. "
- r"Please remove duplicates\.",
+ match=(
+ r"Values of useLegacySql param are duplicated\. api_resource_configs "
+ r"contained useLegacySql param in `query` config and useLegacySql was "
+ r"also provided with arg to run_query\(\) method\. Please remove duplicates\."
+ ),
):
self.hook.run_query(
'query', use_legacy_sql=True, api_resource_configs={'query': {'useLegacySql': False}}
)
def test_validate_value(self):
- with self.assertRaisesRegex(
- TypeError, "case_1 argument must have a type not "
+ with pytest.raises(
+ TypeError, match="case_1 argument must have a type not "
):
_validate_value("case_1", "a", dict)
- self.assertIsNone(_validate_value("case_2", 0, int))
+ assert _validate_value("case_2", 0, int) is None
def test_duplication_check(self):
- with self.assertRaisesRegex(
+ with pytest.raises(
ValueError,
- r"Values of key_one param are duplicated. api_resource_configs contained key_one param in"
+ match=r"Values of key_one param are duplicated. api_resource_configs contained key_one param in"
r" `query` config and key_one was also provided with arg to run_query\(\) method. "
r"Please remove duplicates.",
):
key_one = True
_api_resource_configs_duplication_check("key_one", key_one, {"key_one": False})
- self.assertIsNone(_api_resource_configs_duplication_check("key_one", key_one, {"key_one": True}))
+ assert _api_resource_configs_duplication_check("key_one", key_one, {"key_one": True}) is None
def test_validate_src_fmt_configs(self):
source_format = "test_format"
valid_configs = ["test_config_known", "compatibility_val"]
backward_compatibility_configs = {"compatibility_val": "val"}
- with self.assertRaisesRegex(
- ValueError, "test_config_unknown is not a valid src_fmt_configs for type test_format."
+ with pytest.raises(
+ ValueError, match="test_config_unknown is not a valid src_fmt_configs for type test_format."
):
# This config should raise a value error.
src_fmt_configs = {"test_config_unknown": "val"}
@@ -581,7 +587,7 @@ def test_get_dataset_tables_list(self, mock_client):
mock_client.return_value.list_tables.assert_called_once_with(
dataset=dataset_reference, max_results=None
)
- self.assertEqual(table_list, result)
+ assert table_list == result
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client")
def test_poll_job_complete(self, mock_client):
@@ -670,9 +676,9 @@ def test_get_schema(self, mock_client):
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
def test_invalid_source_format(self, mock_get_service):
- with self.assertRaisesRegex(
+ with pytest.raises(
Exception,
- r"JSON is not a valid source format. Please use one of the following types: \['CSV', "
+ match=r"JSON is not a valid source format. Please use one of the following types: \['CSV', "
r"'NEWLINE_DELIMITED_JSON', 'AVRO', 'GOOGLE_SHEETS', 'DATASTORE_BACKUP', 'PARQUET'\]",
):
self.hook.run_load("test.test", "test_schema.json", ["test_data.json"], source_format="json")
@@ -702,7 +708,7 @@ def test_insert_all_fail(self, mock_client):
rows = [{"json": {"a_key": "a_value_0"}}]
mock_client.return_value.insert_rows.return_value = ["some", "errors"]
- with self.assertRaisesRegex(AirflowException, "insert error"):
+ with pytest.raises(AirflowException, match="insert error"):
self.hook.insert_all(
project_id=PROJECT_ID, dataset_id=DATASET_ID, table_id=TABLE_ID, rows=rows, fail_on_error=True
)
@@ -716,7 +722,7 @@ def test_run_query_with_arg(self, mock_insert):
)
_, kwargs = mock_insert.call_args
- self.assertEqual(kwargs["configuration"]['labels'], {'label1': 'test1', 'label2': 'test2'})
+ assert kwargs["configuration"]['labels'] == {'label1': 'test1', 'label2': 'test2'}
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.QueryJob")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client")
@@ -753,7 +759,7 @@ def test_insert_job(self, mock_client, mock_query_job):
class TestBigQueryTableSplitter(unittest.TestCase):
def test_internal_need_default_project(self):
- with self.assertRaisesRegex(Exception, "INTERNAL: No default project is specified"):
+ with pytest.raises(Exception, match="INTERNAL: No default project is specified"):
_split_tablename("dataset.table", None)
@parameterized.expand(
@@ -768,9 +774,9 @@ def test_internal_need_default_project(self):
def test_split_tablename(self, project_expected, dataset_expected, table_expected, table_input):
default_project_id = "project"
project, dataset, table = _split_tablename(table_input, default_project_id)
- self.assertEqual(project_expected, project)
- self.assertEqual(dataset_expected, dataset)
- self.assertEqual(table_expected, table)
+ assert project_expected == project
+ assert dataset_expected == dataset
+ assert table_expected == table
@parameterized.expand(
[
@@ -800,7 +806,7 @@ def test_split_tablename(self, project_expected, dataset_expected, table_expecte
)
def test_invalid_syntax(self, table_input, var_name, exception_message):
default_project_id = "project"
- with self.assertRaisesRegex(Exception, exception_message.format(table_input)):
+ with pytest.raises(Exception, match=exception_message.format(table_input)):
_split_tablename(table_input, default_project_id, var_name)
@@ -1019,20 +1025,20 @@ def test_execute_many(self, mock_insert, _):
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
def test_description(self, mock_get_service):
bq_cursor = self.hook.get_cursor()
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
bq_cursor.description
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
def test_close(self, mock_get_service):
bq_cursor = self.hook.get_cursor()
result = bq_cursor.close() # pylint: disable=assignment-from-no-return
- self.assertIsNone(result)
+ assert result is None
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
def test_rowcount(self, mock_get_service):
bq_cursor = self.hook.get_cursor()
result = bq_cursor.rowcount
- self.assertEqual(-1, result)
+ assert -1 == result
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryCursor.next")
@@ -1040,7 +1046,7 @@ def test_fetchone(self, mock_next, mock_get_service):
bq_cursor = self.hook.get_cursor()
result = bq_cursor.fetchone()
mock_next.call_count == 1
- self.assertEqual(mock_next.return_value, result)
+ assert mock_next.return_value == result
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
@mock.patch(
@@ -1049,7 +1055,7 @@ def test_fetchone(self, mock_next, mock_get_service):
def test_fetchall(self, mock_fetchone, mock_get_service):
bq_cursor = self.hook.get_cursor()
result = bq_cursor.fetchall()
- self.assertEqual([1, 2, 3], result)
+ assert [1, 2, 3] == result
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryCursor.fetchone")
@@ -1058,22 +1064,22 @@ def test_fetchmany(self, mock_fetchone, mock_get_service):
bq_cursor = self.hook.get_cursor()
mock_fetchone.side_effect = side_effect_values
result = bq_cursor.fetchmany()
- self.assertEqual([1], result)
+ assert [1] == result
mock_fetchone.side_effect = side_effect_values
result = bq_cursor.fetchmany(2)
- self.assertEqual([1, 2], result)
+ assert [1, 2] == result
mock_fetchone.side_effect = side_effect_values
result = bq_cursor.fetchmany(5)
- self.assertEqual([1, 2, 3], result)
+ assert [1, 2, 3] == result
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
def test_next_no_jobid(self, mock_get_service):
bq_cursor = self.hook.get_cursor()
bq_cursor.job_id = None
result = bq_cursor.next()
- self.assertIsNone(result)
+ assert result is None
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
def test_next_buffer(self, mock_get_service):
@@ -1081,12 +1087,12 @@ def test_next_buffer(self, mock_get_service):
bq_cursor.job_id = JOB_ID
bq_cursor.buffer = [1, 2]
result = bq_cursor.next()
- self.assertEqual(1, result)
+ assert 1 == result
result = bq_cursor.next()
- self.assertEqual(2, result)
+ assert 2 == result
bq_cursor.all_pages_loaded = True
result = bq_cursor.next()
- self.assertIsNone(result)
+ assert result is None
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
def test_next(self, mock_get_service):
@@ -1111,10 +1117,10 @@ def test_next(self, mock_get_service):
bq_cursor.location = LOCATION
result = bq_cursor.next()
- self.assertEqual(['one', 1], result)
+ assert ['one', 1] == result
result = bq_cursor.next()
- self.assertEqual(['two', 2], result)
+ assert ['two', 2] == result
mock_get_query_results.assert_called_once_with(
jobId=JOB_ID, location=LOCATION, pageToken=None, projectId='bq-project'
@@ -1133,7 +1139,7 @@ def test_next_no_rows(self, mock_flush_results, mock_get_service):
result = bq_cursor.next()
- self.assertIsNone(result)
+ assert result is None
mock_get_query_results.assert_called_once_with(
jobId=JOB_ID, location=None, pageToken=None, projectId='bq-project'
)
@@ -1156,24 +1162,24 @@ def test_flush_cursor(self, mock_get_service):
bq_cursor.all_pages_loaded = True
bq_cursor.buffer = [('a', 100, 200), ('b', 200, 300)]
bq_cursor.flush_results()
- self.assertIsNone(bq_cursor.page_token)
- self.assertIsNone(bq_cursor.job_id)
- self.assertFalse(bq_cursor.all_pages_loaded)
- self.assertListEqual(bq_cursor.buffer, [])
+ assert bq_cursor.page_token is None
+ assert bq_cursor.job_id is None
+ assert not bq_cursor.all_pages_loaded
+ assert bq_cursor.buffer == []
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
def test_arraysize(self, mock_get_service):
bq_cursor = self.hook.get_cursor()
- self.assertIsNone(bq_cursor.buffersize)
- self.assertEqual(bq_cursor.arraysize, 1)
+ assert bq_cursor.buffersize is None
+ assert bq_cursor.arraysize == 1
bq_cursor.set_arraysize(10)
- self.assertEqual(bq_cursor.buffersize, 10)
- self.assertEqual(bq_cursor.arraysize, 10)
+ assert bq_cursor.buffersize == 10
+ assert bq_cursor.arraysize == 10
class TestDatasetsOperations(_BigQueryBaseTestClass):
def test_create_empty_dataset_no_dataset_id_err(self):
- with self.assertRaisesRegex(ValueError, r"Please specify `datasetId`"):
+ with pytest.raises(ValueError, match=r"Please specify `datasetId`"):
self.hook.create_empty_dataset(dataset_id=None, project_id=None)
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.Dataset")
@@ -1242,7 +1248,7 @@ def test_get_dataset(self, mock_client):
dataset_ref=DatasetReference(PROJECT_ID, DATASET_ID)
)
- self.assertEqual(result, expected_result)
+ assert result == expected_result
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client")
def test_get_datasets_list(self, mock_client):
@@ -1342,14 +1348,14 @@ def test_run_load_default(self, mock_insert):
)
_, kwargs = mock_insert.call_args
- self.assertIs(kwargs["configuration"]['load'].get('timePartitioning'), None)
+ assert kwargs["configuration"]['load'].get('timePartitioning') is None
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
def test_run_with_auto_detect(self, mock_insert):
destination_project_dataset_table = "autodetect.table"
self.hook.run_load(destination_project_dataset_table, [], [], autodetect=True)
_, kwargs = mock_insert.call_args
- self.assertIs(kwargs["configuration"]['load']['autodetect'], True)
+ assert kwargs["configuration"]['load']['autodetect'] is True
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
def test_run_load_with_arg(self, mock_insert):
@@ -1406,7 +1412,7 @@ def test_run_query_with_arg(self, mock_insert):
def test_dollar_makes_partition(self):
tp_out = _cleanse_time_partitioning('test.teast$20170101', {})
expect = {'type': 'DAY'}
- self.assertEqual(tp_out, expect)
+ assert tp_out == expect
def test_extra_time_partitioning_options(self):
tp_out = _cleanse_time_partitioning(
@@ -1414,7 +1420,7 @@ def test_extra_time_partitioning_options(self):
)
expect = {'type': 'DAY', 'field': 'test_field', 'expirationMs': 1000}
- self.assertEqual(tp_out, expect)
+ assert tp_out == expect
class TestClusteringInRunJob(_BigQueryBaseTestClass):
@@ -1427,7 +1433,7 @@ def test_run_load_default(self, mock_insert):
)
_, kwargs = mock_insert.call_args
- self.assertIsNone(kwargs["configuration"]['load'].get('clustering'))
+ assert kwargs["configuration"]['load'].get('clustering') is None
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
def test_run_load_with_arg(self, mock_insert):
@@ -1440,14 +1446,14 @@ def test_run_load_with_arg(self, mock_insert):
)
_, kwargs = mock_insert.call_args
- self.assertEqual(kwargs["configuration"]['load']['clustering'], {'fields': ['field1', 'field2']})
+ assert kwargs["configuration"]['load']['clustering'] == {'fields': ['field1', 'field2']}
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
def test_run_query_default(self, mock_insert):
self.hook.run_query(sql='select 1')
_, kwargs = mock_insert.call_args
- self.assertIsNone(kwargs["configuration"]['query'].get('clustering'))
+ assert kwargs["configuration"]['query'].get('clustering') is None
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
def test_run_query_with_arg(self, mock_insert):
@@ -1459,7 +1465,7 @@ def test_run_query_with_arg(self, mock_insert):
)
_, kwargs = mock_insert.call_args
- self.assertEqual(kwargs["configuration"]['query']['clustering'], {'fields': ['field1', 'field2']})
+ assert kwargs["configuration"]['query']['clustering'] == {'fields': ['field1', 'field2']}
class TestBigQueryHookLegacySql(_BigQueryBaseTestClass):
@@ -1470,7 +1476,7 @@ class TestBigQueryHookLegacySql(_BigQueryBaseTestClass):
def test_hook_uses_legacy_sql_by_default(self, mock_insert, _):
self.hook.get_first('query')
_, kwargs = mock_insert.call_args
- self.assertIs(kwargs["configuration"]['query']['useLegacySql'], True)
+ assert kwargs["configuration"]['query']['useLegacySql'] is True
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id',
@@ -1484,7 +1490,7 @@ def test_legacy_sql_override_propagates_properly(
bq_hook = BigQueryHook(use_legacy_sql=False)
bq_hook.get_first('query')
_, kwargs = mock_insert.call_args
- self.assertIs(kwargs["configuration"]['query']['useLegacySql'], False)
+ assert kwargs["configuration"]['query']['useLegacySql'] is False
class TestBigQueryHookRunWithConfiguration(_BigQueryBaseTestClass):
@@ -1665,8 +1671,8 @@ def test_run_query_with_kms(self, mock_insert):
encryption_configuration = {"kms_key_name": "projects/p/locations/l/keyRings/k/cryptoKeys/c"}
self.hook.run_query(sql='query', encryption_configuration=encryption_configuration)
_, kwargs = mock_insert.call_args
- self.assertIs(
- kwargs["configuration"]['query']['destinationEncryptionConfiguration'], encryption_configuration
+ assert (
+ kwargs["configuration"]['query']['destinationEncryptionConfiguration'] is encryption_configuration
)
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
@@ -1678,8 +1684,8 @@ def test_run_copy_with_kms(self, mock_insert):
encryption_configuration=encryption_configuration,
)
_, kwargs = mock_insert.call_args
- self.assertIs(
- kwargs["configuration"]['copy']['destinationEncryptionConfiguration'], encryption_configuration
+ assert (
+ kwargs["configuration"]['copy']['destinationEncryptionConfiguration'] is encryption_configuration
)
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
@@ -1692,8 +1698,8 @@ def test_run_load_with_kms(self, mock_insert):
encryption_configuration=encryption_configuration,
)
_, kwargs = mock_insert.call_args
- self.assertIs(
- kwargs["configuration"]['load']['destinationEncryptionConfiguration'], encryption_configuration
+ assert (
+ kwargs["configuration"]['load']['destinationEncryptionConfiguration'] is encryption_configuration
)
@@ -1737,8 +1743,8 @@ def test_deprecation_warning(self, func_name, mock_bq_hook):
bq_cursor = BigQueryCursor(mock.MagicMock(), PROJECT_ID, mock_bq_hook)
func = getattr(bq_cursor, func_name)
- with self.assertWarnsRegex(DeprecationWarning, message_regex):
+ with pytest.warns(DeprecationWarning, match=message_regex):
_ = func(*args, **kwargs)
mocked_func.assert_called_once_with(*args, **kwargs)
- self.assertRegex(func.__doc__, f".*{new_path}.*")
+ assert re.search(f".*{new_path}.*", func.__doc__)
diff --git a/tests/providers/google/cloud/hooks/test_bigquery_dts.py b/tests/providers/google/cloud/hooks/test_bigquery_dts.py
index c4c7c186073f7..b53cb7637e1e3 100644
--- a/tests/providers/google/cloud/hooks/test_bigquery_dts.py
+++ b/tests/providers/google/cloud/hooks/test_bigquery_dts.py
@@ -59,12 +59,12 @@ def setUp(self) -> None:
def test_version_information(self):
expected_version = "airflow_v" + version
- self.assertEqual(expected_version, self.hook.client_info.client_library_version)
+ assert expected_version == self.hook.client_info.client_library_version
def test_disable_auto_scheduling(self):
expected = deepcopy(TRANSFER_CONFIG)
expected.schedule_options.disable_auto_scheduling = True
- self.assertEqual(expected, self.hook._disable_auto_scheduling(TRANSFER_CONFIG))
+ assert expected == self.hook._disable_auto_scheduling(TRANSFER_CONFIG)
@mock.patch(
"airflow.providers.google.cloud.hooks.bigquery_dts."
diff --git a/tests/providers/google/cloud/hooks/test_bigquery_system.py b/tests/providers/google/cloud/hooks/test_bigquery_system.py
index 78978c1a938df..4cfaa60c2a5ca 100644
--- a/tests/providers/google/cloud/hooks/test_bigquery_system.py
+++ b/tests/providers/google/cloud/hooks/test_bigquery_system.py
@@ -35,22 +35,22 @@ def test_output_is_dataframe_with_valid_query(self):
import pandas as pd
df = self.instance.get_pandas_df('select 1')
- self.assertIsInstance(df, pd.DataFrame)
+ assert isinstance(df, pd.DataFrame)
def test_throws_exception_with_invalid_query(self):
- with self.assertRaises(Exception) as context:
+ with pytest.raises(Exception) as ctx:
self.instance.get_pandas_df('from `1`')
- self.assertIn('Reason: ', str(context.exception), "")
+ assert 'Reason: ' in str(ctx.value), ""
def test_succeeds_with_explicit_legacy_query(self):
df = self.instance.get_pandas_df('select 1', dialect='legacy')
- self.assertEqual(df.iloc(0)[0][0], 1)
+ assert df.iloc(0)[0][0] == 1
def test_succeeds_with_explicit_std_query(self):
df = self.instance.get_pandas_df('select * except(b) from (select 1 a, 2 b)', dialect='standard')
- self.assertEqual(df.iloc(0)[0][0], 1)
+ assert df.iloc(0)[0][0] == 1
def test_throws_exception_with_incompatible_syntax(self):
- with self.assertRaises(Exception) as context:
+ with pytest.raises(Exception) as ctx:
self.instance.get_pandas_df('select * except(b) from (select 1 a, 2 b)', dialect='legacy')
- self.assertIn('Reason: ', str(context.exception), "")
+ assert 'Reason: ' in str(ctx.value), ""
diff --git a/tests/providers/google/cloud/hooks/test_bigtable.py b/tests/providers/google/cloud/hooks/test_bigtable.py
index a8809ed651f83..a452c483a82e3 100644
--- a/tests/providers/google/cloud/hooks/test_bigtable.py
+++ b/tests/providers/google/cloud/hooks/test_bigtable.py
@@ -70,8 +70,8 @@ def test_bigtable_client_creation(self, mock_client, mock_get_creds, mock_client
client_info=mock_client_info.return_value,
admin=True,
)
- self.assertEqual(mock_client.return_value, result)
- self.assertEqual(self.bigtable_hook_no_default_project_id._client, result)
+ assert mock_client.return_value == result
+ assert self.bigtable_hook_no_default_project_id._client == result
@mock.patch('airflow.providers.google.cloud.hooks.bigtable.BigtableHook._get_client')
def test_get_instance_overridden_project_id(self, get_client):
@@ -84,7 +84,7 @@ def test_get_instance_overridden_project_id(self, get_client):
instance_method.assert_called_once_with('instance')
instance_exists_method.assert_called_once_with()
get_client.assert_called_once_with(project_id='example-project')
- self.assertIsNotNone(res)
+ assert res is not None
@mock.patch('airflow.providers.google.cloud.hooks.bigtable.BigtableHook._get_client')
def test_delete_instance_overridden_project_id(self, get_client):
@@ -99,7 +99,7 @@ def test_delete_instance_overridden_project_id(self, get_client):
instance_exists_method.assert_called_once_with()
delete_method.assert_called_once_with()
get_client.assert_called_once_with(project_id='example-project')
- self.assertIsNone(res)
+ assert res is None
@mock.patch('google.cloud.bigtable.instance.Instance.create')
@mock.patch('airflow.providers.google.cloud.hooks.bigtable.BigtableHook._get_client')
@@ -115,7 +115,7 @@ def test_create_instance_overridden_project_id(self, get_client, instance_create
)
get_client.assert_called_once_with(project_id='example-project')
instance_create.assert_called_once_with(clusters=mock.ANY)
- self.assertEqual(res.instance_id, 'instance')
+ assert res.instance_id == 'instance'
@mock.patch('google.cloud.bigtable.instance.Instance.update')
@mock.patch('airflow.providers.google.cloud.hooks.bigtable.BigtableHook._get_client')
@@ -132,7 +132,7 @@ def test_update_instance_overridden_project_id(self, get_client, instance_update
)
get_client.assert_called_once_with(project_id='example-project')
instance_update.assert_called_once_with()
- self.assertEqual(res.instance_id, 'instance')
+ assert res.instance_id == 'instance'
@mock.patch('airflow.providers.google.cloud.hooks.bigtable.BigtableHook._get_client')
def test_delete_table_overridden_project_id(self, get_client):
@@ -170,8 +170,8 @@ def test_bigtable_client_creation(self, mock_client, mock_get_creds, mock_client
client_info=mock_client_info.return_value,
admin=True,
)
- self.assertEqual(mock_client.return_value, result)
- self.assertEqual(self.bigtable_hook_default_project_id._client, result)
+ assert mock_client.return_value == result
+ assert self.bigtable_hook_default_project_id._client == result
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -190,7 +190,7 @@ def test_get_instance(self, get_client, mock_project_id):
instance_method.assert_called_once_with('instance')
instance_exists_method.assert_called_once_with()
get_client.assert_called_once_with(project_id='example-project')
- self.assertIsNotNone(res)
+ assert res is not None
@mock.patch('airflow.providers.google.cloud.hooks.bigtable.BigtableHook._get_client')
def test_get_instance_overridden_project_id(self, get_client):
@@ -203,7 +203,7 @@ def test_get_instance_overridden_project_id(self, get_client):
instance_method.assert_called_once_with('instance')
instance_exists_method.assert_called_once_with()
get_client.assert_called_once_with(project_id='new-project')
- self.assertIsNotNone(res)
+ assert res is not None
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -222,7 +222,7 @@ def test_get_instance_no_instance(self, get_client, mock_project_id):
instance_method.assert_called_once_with('instance')
instance_exists_method.assert_called_once_with()
get_client.assert_called_once_with(project_id='example-project')
- self.assertIsNone(res)
+ assert res is None
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -243,7 +243,7 @@ def test_delete_instance(self, get_client, mock_project_id):
instance_exists_method.assert_called_once_with()
delete_method.assert_called_once_with()
get_client.assert_called_once_with(project_id='example-project')
- self.assertIsNone(res)
+ assert res is None
@mock.patch('airflow.providers.google.cloud.hooks.bigtable.BigtableHook._get_client')
def test_delete_instance_overridden_project_id(self, get_client):
@@ -258,7 +258,7 @@ def test_delete_instance_overridden_project_id(self, get_client):
instance_exists_method.assert_called_once_with()
delete_method.assert_called_once_with()
get_client.assert_called_once_with(project_id='new-project')
- self.assertIsNone(res)
+ assert res is None
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -299,7 +299,7 @@ def test_create_instance(self, get_client, instance_create, mock_project_id):
)
get_client.assert_called_once_with(project_id='example-project')
instance_create.assert_called_once_with(clusters=mock.ANY)
- self.assertEqual(res.instance_id, 'instance')
+ assert res.instance_id == 'instance'
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -337,7 +337,7 @@ def test_create_instance_with_one_replica_cluster(
)
get_client.assert_called_once_with(project_id='example-project')
instance_create.assert_called_once_with(clusters=mock.ANY)
- self.assertEqual(res.instance_id, 'instance')
+ assert res.instance_id == 'instance'
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -374,7 +374,7 @@ def test_create_instance_with_multiple_replica_clusters(
)
get_client.assert_called_once_with(project_id='example-project')
instance_create.assert_called_once_with(clusters=mock.ANY)
- self.assertEqual(res.instance_id, 'instance')
+ assert res.instance_id == 'instance'
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -396,7 +396,7 @@ def test_update_instance(self, get_client, instance_update, mock_project_id):
)
get_client.assert_called_once_with(project_id='example-project')
instance_update.assert_called_once_with()
- self.assertEqual(res.instance_id, 'instance')
+ assert res.instance_id == 'instance'
@mock.patch('google.cloud.bigtable.instance.Instance.create')
@mock.patch('airflow.providers.google.cloud.hooks.bigtable.BigtableHook._get_client')
@@ -412,7 +412,7 @@ def test_create_instance_overridden_project_id(self, get_client, instance_create
)
get_client.assert_called_once_with(project_id='new-project')
instance_create.assert_called_once_with(clusters=mock.ANY)
- self.assertEqual(res.instance_id, 'instance')
+ assert res.instance_id == 'instance'
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
diff --git a/tests/providers/google/cloud/hooks/test_cloud_build.py b/tests/providers/google/cloud/hooks/test_cloud_build.py
index 9a3b67ec258e6..ebf7163206757 100644
--- a/tests/providers/google/cloud/hooks/test_cloud_build.py
+++ b/tests/providers/google/cloud/hooks/test_cloud_build.py
@@ -23,6 +23,8 @@
from unittest import mock
from unittest.mock import PropertyMock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.cloud_build import CloudBuildHook
from tests.providers.google.cloud.utils.base_gcp_mock import (
@@ -63,8 +65,8 @@ def test_cloud_build_client_creation(self, mock_build, mock_authorize):
mock_build.assert_called_once_with(
'cloudbuild', 'v1', http=mock_authorize.return_value, cache_discovery=False
)
- self.assertEqual(mock_build.return_value, result)
- self.assertEqual(self.hook._conn, result)
+ assert mock_build.return_value == result
+ assert self.hook._conn == result
@mock.patch("airflow.providers.google.cloud.hooks.cloud_build.CloudBuildHook.get_conn")
def test_build_immediately_complete(self, get_conn_mock):
@@ -86,7 +88,7 @@ def test_build_immediately_complete(self, get_conn_mock):
body={}, projectId=TEST_PROJECT_ID
)
- self.assertEqual(result, TEST_BUILD)
+ assert result == TEST_BUILD
@mock.patch("airflow.providers.google.cloud.hooks.cloud_build.CloudBuildHook.get_conn")
@mock.patch("airflow.providers.google.cloud.hooks.cloud_build.time.sleep")
@@ -108,7 +110,7 @@ def test_waiting_operation(self, _, get_conn_mock):
result = self.hook.create_build(body={}, project_id=TEST_PROJECT_ID)
- self.assertEqual(result, TEST_BUILD)
+ assert result == TEST_BUILD
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -126,7 +128,7 @@ def test_error_operation(self, _, get_conn_mock, mock_project_id):
execute_mock = mock.Mock(**{"side_effect": [TEST_WAITING_OPERATION, TEST_ERROR_OPERATION]})
service_mock.operations.return_value.get.return_value.execute = execute_mock
- with self.assertRaisesRegex(AirflowException, "error"):
+ with pytest.raises(AirflowException, match="error"):
self.hook.create_build(body={}) # pylint: disable=no-value-for-parameter
@@ -147,8 +149,8 @@ def test_cloud_build_client_creation(self, mock_build, mock_authorize):
mock_build.assert_called_once_with(
'cloudbuild', 'v1', http=mock_authorize.return_value, cache_discovery=False
)
- self.assertEqual(mock_build.return_value, result)
- self.assertEqual(self.hook._conn, result)
+ assert mock_build.return_value == result
+ assert self.hook._conn == result
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -175,7 +177,7 @@ def test_build_immediately_complete(self, get_conn_mock, mock_project_id):
body={}, projectId='example-project'
)
- self.assertEqual(result, TEST_BUILD)
+ assert result == TEST_BUILD
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -202,7 +204,7 @@ def test_waiting_operation(self, _, get_conn_mock, mock_project_id):
result = self.hook.create_build(body={}) # pylint: disable=no-value-for-parameter
- self.assertEqual(result, TEST_BUILD)
+ assert result == TEST_BUILD
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -220,7 +222,7 @@ def test_error_operation(self, _, get_conn_mock, mock_project_id):
execute_mock = mock.Mock(**{"side_effect": [TEST_WAITING_OPERATION, TEST_ERROR_OPERATION]})
service_mock.operations.return_value.get.return_value.execute = execute_mock
- with self.assertRaisesRegex(AirflowException, "error"):
+ with pytest.raises(AirflowException, match="error"):
self.hook.create_build(body={}) # pylint: disable=no-value-for-parameter
@@ -241,8 +243,8 @@ def test_cloud_build_client_creation(self, mock_build, mock_authorize):
mock_build.assert_called_once_with(
'cloudbuild', 'v1', http=mock_authorize.return_value, cache_discovery=False
)
- self.assertEqual(mock_build.return_value, result)
- self.assertEqual(self.hook._conn, result)
+ assert mock_build.return_value == result
+ assert self.hook._conn == result
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -251,11 +253,10 @@ def test_cloud_build_client_creation(self, mock_build, mock_authorize):
)
@mock.patch("airflow.providers.google.cloud.hooks.cloud_build.CloudBuildHook.get_conn")
def test_create_build(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
self.hook.create_build(body={}) # pylint: disable=no-value-for-parameter
- self.assertEqual(
+ assert (
"The project id must be passed either as keyword project_id parameter or as project_id extra in "
- "Google Cloud connection definition. Both are not set!",
- str(e.exception),
+ "Google Cloud connection definition. Both are not set!" == str(ctx.value)
)
diff --git a/tests/providers/google/cloud/hooks/test_cloud_memorystore.py b/tests/providers/google/cloud/hooks/test_cloud_memorystore.py
index 64050ca400573..9e6f442236b4a 100644
--- a/tests/providers/google/cloud/hooks/test_cloud_memorystore.py
+++ b/tests/providers/google/cloud/hooks/test_cloud_memorystore.py
@@ -19,6 +19,7 @@
from unittest import TestCase, mock
from unittest.mock import PropertyMock
+import pytest
from google.api_core.retry import Retry
from google.cloud.exceptions import NotFound
from google.cloud.memcache_v1beta2.types import cloud_memcache
@@ -89,7 +90,7 @@ def test_create_instance_when_exists(self, mock_get_conn, mock_project_id):
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
)
- self.assertEqual(Instance(name=TEST_NAME), result)
+ assert Instance(name=TEST_NAME) == result
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -130,7 +131,7 @@ def test_create_instance_when_not_exists(self, mock_get_conn, mock_project_id):
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
)
- self.assertEqual(Instance(name=TEST_NAME), result)
+ assert Instance(name=TEST_NAME) == result
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -247,7 +248,7 @@ def test_create_instance_when_exists(self, mock_get_conn):
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
)
- self.assertEqual(Instance(name=TEST_NAME), result)
+ assert Instance(name=TEST_NAME) == result
@mock.patch("airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn")
def test_create_instance_when_not_exists(self, mock_get_conn):
@@ -295,7 +296,7 @@ def test_create_instance_when_not_exists(self, mock_get_conn):
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
)
- self.assertEqual(Instance(name=TEST_NAME), result)
+ assert Instance(name=TEST_NAME) == result
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -304,7 +305,7 @@ def test_create_instance_when_not_exists(self, mock_get_conn):
)
@mock.patch("airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn")
def test_create_instance_without_project_id(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.create_instance(
location=TEST_LOCATION,
instance_id=TEST_INSTANCE_ID,
@@ -336,7 +337,7 @@ def test_delete_instance(self, mock_get_conn):
)
@mock.patch("airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn")
def test_delete_instance_without_project_id(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.delete_instance(
location=TEST_LOCATION,
instance=Instance(name=TEST_NAME),
@@ -367,7 +368,7 @@ def test_get_instance(self, mock_get_conn):
)
@mock.patch("airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn")
def test_get_instance_without_project_id(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.get_instance(
location=TEST_LOCATION,
instance=Instance(name=TEST_NAME),
@@ -401,7 +402,7 @@ def test_list_instances(self, mock_get_conn):
)
@mock.patch("airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn")
def test_list_instances_without_project_id(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.list_instances(
location=TEST_LOCATION,
page_size=TEST_PAGE_SIZE,
@@ -435,7 +436,7 @@ def test_update_instance(self, mock_get_conn):
)
@mock.patch("airflow.providers.google.cloud.hooks.cloud_memorystore.CloudMemorystoreHook.get_conn")
def test_update_instance_without_project_id(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.update_instance( # pylint: disable=no-value-for-parameter
update_mask=TEST_UPDATE_MASK,
instance=Instance(name=TEST_NAME),
@@ -476,7 +477,7 @@ def test_create_instance_when_exists(self, mock_get_conn, mock_project_id):
mock_get_conn.return_value.get_instance.assert_called_once_with(
name=TEST_NAME_DEFAULT_PROJECT_ID, retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA
)
- self.assertEqual(cloud_memcache.Instance(name=TEST_NAME), result)
+ assert cloud_memcache.Instance(name=TEST_NAME) == result
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -519,7 +520,7 @@ def test_create_instance_when_not_exists(self, mock_get_conn, mock_project_id):
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
)
- self.assertEqual(cloud_memcache.Instance(name=TEST_NAME), result)
+ assert cloud_memcache.Instance(name=TEST_NAME) == result
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
diff --git a/tests/providers/google/cloud/hooks/test_cloud_sql.py b/tests/providers/google/cloud/hooks/test_cloud_sql.py
index 7aa691f431a08..003245d8eddba 100644
--- a/tests/providers/google/cloud/hooks/test_cloud_sql.py
+++ b/tests/providers/google/cloud/hooks/test_cloud_sql.py
@@ -24,6 +24,7 @@
from unittest.mock import PropertyMock
import httplib2
+import pytest
from googleapiclient.errors import HttpError
from parameterized import parameterized
@@ -52,13 +53,13 @@ def test_instance_import_exception(self, mock_get_credentials):
self.cloudsql_hook.get_conn = mock.Mock(
side_effect=HttpError(resp=httplib2.Response({'status': 400}), content=b'Error content')
)
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
self.cloudsql_hook.import_instance( # pylint: disable=no-value-for-parameter
instance='instance', body={}
)
- err = cm.exception
- self.assertIn("Importing instance ", str(err))
- self.assertEqual(1, mock_get_credentials.call_count)
+ err = ctx.value
+ assert "Importing instance " in str(err)
+ assert 1 == mock_get_credentials.call_count
@mock.patch(
'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id',
@@ -68,13 +69,13 @@ def test_instance_export_exception(self, mock_get_credentials):
self.cloudsql_hook.get_conn = mock.Mock(
side_effect=HttpError(resp=httplib2.Response({'status': 400}), content=b'Error content')
)
- with self.assertRaises(HttpError) as cm:
+ with pytest.raises(HttpError) as ctx:
self.cloudsql_hook.export_instance( # pylint: disable=no-value-for-parameter
instance='instance', body={}
)
- err = cm.exception
- self.assertEqual(400, err.resp.status)
- self.assertEqual(1, mock_get_credentials.call_count)
+ err = ctx.value
+ assert 400 == err.resp.status
+ assert 1 == mock_get_credentials.call_count
@mock.patch(
'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id',
@@ -96,7 +97,7 @@ def test_instance_import(self, wait_for_operation_to_complete, get_conn, mock_ge
wait_for_operation_to_complete.assert_called_once_with(
project_id='example-project', operation_name='operation_id'
)
- self.assertEqual(1, mock_get_credentials.call_count)
+ assert 1 == mock_get_credentials.call_count
@mock.patch(
'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id',
@@ -118,7 +119,7 @@ def test_instance_export(self, wait_for_operation_to_complete, get_conn, mock_ge
wait_for_operation_to_complete.assert_called_once_with(
project_id='example-project', operation_name='operation_id'
)
- self.assertEqual(1, mock_get_credentials.call_count)
+ assert 1 == mock_get_credentials.call_count
@mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn')
@mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._wait_for_operation_to_complete')
@@ -141,8 +142,8 @@ def test_instance_export_with_in_progress_retry(self, wait_for_operation_to_comp
wait_for_operation_to_complete.return_value = None
self.cloudsql_hook.export_instance(project_id='example-project', instance='instance', body={})
- self.assertEqual(2, export_method.call_count)
- self.assertEqual(2, execute_method.call_count)
+ assert 2 == export_method.call_count
+ assert 2 == execute_method.call_count
wait_for_operation_to_complete.assert_called_once_with(
project_id='example-project', operation_name='operation_id'
)
@@ -159,12 +160,12 @@ def test_get_instance(self, wait_for_operation_to_complete, get_conn, mock_get_c
execute_method.return_value = {"name": "instance"}
wait_for_operation_to_complete.return_value = None
res = self.cloudsql_hook.get_instance(instance='instance') # pylint: disable=no-value-for-parameter
- self.assertIsNotNone(res)
- self.assertEqual('instance', res['name'])
+ assert res is not None
+ assert 'instance' == res['name']
get_method.assert_called_once_with(instance='instance', project='example-project')
execute_method.assert_called_once_with(num_retries=5)
wait_for_operation_to_complete.assert_not_called()
- self.assertEqual(1, mock_get_credentials.call_count)
+ assert 1 == mock_get_credentials.call_count
@mock.patch(
'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id',
@@ -184,7 +185,7 @@ def test_create_instance(self, wait_for_operation_to_complete, get_conn, mock_ge
wait_for_operation_to_complete.assert_called_once_with(
operation_name='operation_id', project_id='example-project'
)
- self.assertEqual(1, mock_get_credentials.call_count)
+ assert 1 == mock_get_credentials.call_count
@mock.patch(
'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id',
@@ -213,9 +214,9 @@ def test_create_instance_with_in_progress_retry(
wait_for_operation_to_complete.return_value = None
self.cloudsql_hook.create_instance(body={}) # pylint: disable=no-value-for-parameter
- self.assertEqual(1, mock_get_credentials.call_count)
- self.assertEqual(2, insert_method.call_count)
- self.assertEqual(2, execute_method.call_count)
+ assert 1 == mock_get_credentials.call_count
+ assert 2 == insert_method.call_count
+ assert 2 == execute_method.call_count
wait_for_operation_to_complete.assert_called_once_with(
operation_name='operation_id', project_id='example-project'
)
@@ -249,9 +250,9 @@ def test_patch_instance_with_in_progress_retry(
instance='instance', body={}
)
- self.assertEqual(1, mock_get_credentials.call_count)
- self.assertEqual(2, patch_method.call_count)
- self.assertEqual(2, execute_method.call_count)
+ assert 1 == mock_get_credentials.call_count
+ assert 2 == patch_method.call_count
+ assert 2 == execute_method.call_count
wait_for_operation_to_complete.assert_called_once_with(
operation_name='operation_id', project_id='example-project'
)
@@ -276,7 +277,7 @@ def test_patch_instance(self, wait_for_operation_to_complete, get_conn, mock_get
wait_for_operation_to_complete.assert_called_once_with(
operation_name='operation_id', project_id='example-project'
)
- self.assertEqual(1, mock_get_credentials.call_count)
+ assert 1 == mock_get_credentials.call_count
@mock.patch(
'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id',
@@ -296,7 +297,7 @@ def test_delete_instance(self, wait_for_operation_to_complete, get_conn, mock_ge
wait_for_operation_to_complete.assert_called_once_with(
operation_name='operation_id', project_id='example-project'
)
- self.assertEqual(1, mock_get_credentials.call_count)
+ assert 1 == mock_get_credentials.call_count
@mock.patch(
'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id',
@@ -325,9 +326,9 @@ def test_delete_instance_with_in_progress_retry(
wait_for_operation_to_complete.return_value = None
self.cloudsql_hook.delete_instance(instance='instance') # pylint: disable=no-value-for-parameter
- self.assertEqual(1, mock_get_credentials.call_count)
- self.assertEqual(2, delete_method.call_count)
- self.assertEqual(2, execute_method.call_count)
+ assert 1 == mock_get_credentials.call_count
+ assert 2 == delete_method.call_count
+ assert 2 == execute_method.call_count
wait_for_operation_to_complete.assert_called_once_with(
operation_name='operation_id', project_id='example-project'
)
@@ -346,14 +347,14 @@ def test_get_database(self, wait_for_operation_to_complete, get_conn, mock_get_c
res = self.cloudsql_hook.get_database( # pylint: disable=no-value-for-parameter
database='database', instance='instance'
)
- self.assertIsNotNone(res)
- self.assertEqual('database', res['name'])
+ assert res is not None
+ assert 'database' == res['name']
get_method.assert_called_once_with(
instance='instance', database='database', project='example-project'
)
execute_method.assert_called_once_with(num_retries=5)
wait_for_operation_to_complete.assert_not_called()
- self.assertEqual(1, mock_get_credentials.call_count)
+ assert 1 == mock_get_credentials.call_count
@mock.patch(
'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id',
@@ -375,7 +376,7 @@ def test_create_database(self, wait_for_operation_to_complete, get_conn, mock_ge
wait_for_operation_to_complete.assert_called_once_with(
operation_name='operation_id', project_id='example-project'
)
- self.assertEqual(1, mock_get_credentials.call_count)
+ assert 1 == mock_get_credentials.call_count
@mock.patch(
'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id',
@@ -406,9 +407,9 @@ def test_create_database_with_in_progress_retry(
instance='instance', body={}
)
- self.assertEqual(1, mock_get_credentials.call_count)
- self.assertEqual(2, insert_method.call_count)
- self.assertEqual(2, execute_method.call_count)
+ assert 1 == mock_get_credentials.call_count
+ assert 2 == insert_method.call_count
+ assert 2 == execute_method.call_count
wait_for_operation_to_complete.assert_called_once_with(
operation_name='operation_id', project_id='example-project'
)
@@ -435,7 +436,7 @@ def test_patch_database(self, wait_for_operation_to_complete, get_conn, mock_get
wait_for_operation_to_complete.assert_called_once_with(
operation_name='operation_id', project_id='example-project'
)
- self.assertEqual(1, mock_get_credentials.call_count)
+ assert 1 == mock_get_credentials.call_count
@mock.patch(
'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id',
@@ -466,9 +467,9 @@ def test_patch_database_with_in_progress_retry(
instance='instance', database='database', body={}
)
- self.assertEqual(1, mock_get_credentials.call_count)
- self.assertEqual(2, patch_method.call_count)
- self.assertEqual(2, execute_method.call_count)
+ assert 1 == mock_get_credentials.call_count
+ assert 2 == patch_method.call_count
+ assert 2 == execute_method.call_count
wait_for_operation_to_complete.assert_called_once_with(
operation_name='operation_id', project_id='example-project'
)
@@ -495,7 +496,7 @@ def test_delete_database(self, wait_for_operation_to_complete, get_conn, mock_ge
wait_for_operation_to_complete.assert_called_once_with(
operation_name='operation_id', project_id='example-project'
)
- self.assertEqual(1, mock_get_credentials.call_count)
+ assert 1 == mock_get_credentials.call_count
@mock.patch(
'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook._get_credentials_and_project_id',
@@ -526,9 +527,9 @@ def test_delete_database_with_in_progress_retry(
instance='instance', database='database'
)
- self.assertEqual(1, mock_get_credentials.call_count)
- self.assertEqual(2, delete_method.call_count)
- self.assertEqual(2, execute_method.call_count)
+ assert 1 == mock_get_credentials.call_count
+ assert 2 == delete_method.call_count
+ assert 2 == execute_method.call_count
wait_for_operation_to_complete.assert_called_once_with(
operation_name='operation_id', project_id='example-project'
)
@@ -605,8 +606,8 @@ def test_get_instance_overridden_project_id(
res = self.cloudsql_hook_no_default_project_id.get_instance(
project_id='example-project', instance='instance'
)
- self.assertIsNotNone(res)
- self.assertEqual('instance', res['name'])
+ assert res is not None
+ assert 'instance' == res['name']
get_method.assert_called_once_with(instance='instance', project='example-project')
execute_method.assert_called_once_with(num_retries=5)
wait_for_operation_to_complete.assert_not_called()
@@ -695,8 +696,8 @@ def test_get_database_overridden_project_id(
res = self.cloudsql_hook_no_default_project_id.get_database(
project_id='example-project', database='database', instance='instance'
)
- self.assertIsNotNone(res)
- self.assertEqual('database', res['name'])
+ assert res is not None
+ assert 'database' == res['name']
get_method.assert_called_once_with(
instance='instance', database='database', project='example-project'
)
@@ -816,10 +817,10 @@ def test_cloudsql_database_hook_validate_ssl_certs_missing_cert_params(
hook = CloudSQLDatabaseHook(
gcp_cloudsql_conn_id='cloudsql_connection', default_gcp_project_id='google_connection'
)
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
hook.validate_ssl_certs()
- err = cm.exception
- self.assertIn("SSL connections requires", str(err))
+ err = ctx.value
+ assert "SSL connections requires" in str(err)
@mock.patch('os.path.isfile')
@mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection')
@@ -869,10 +870,10 @@ def test_cloudsql_database_hook_validate_ssl_certs_with_ssl_files_not_readable(
hook = CloudSQLDatabaseHook(
gcp_cloudsql_conn_id='cloudsql_connection', default_gcp_project_id='google_connection'
)
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
hook.validate_ssl_certs()
- err = cm.exception
- self.assertIn("must be a readable file", str(err))
+ err = ctx.value
+ assert "must be a readable file" in str(err)
@mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection')
def test_cloudsql_database_hook_validate_socket_path_length_too_long(self, get_connection):
@@ -892,10 +893,10 @@ def test_cloudsql_database_hook_validate_socket_path_length_too_long(self, get_c
hook = CloudSQLDatabaseHook(
gcp_cloudsql_conn_id='cloudsql_connection', default_gcp_project_id='google_connection'
)
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
hook.validate_socket_path_length()
- err = cm.exception
- self.assertIn("The UNIX socket path length cannot exceed", str(err))
+ err = ctx.value
+ assert "The UNIX socket path length cannot exceed" in str(err)
@mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection')
def test_cloudsql_database_hook_validate_socket_path_length_not_too_long(self, get_connection):
@@ -943,10 +944,10 @@ def test_cloudsql_database_hook_create_connection_missing_fields(self, uri, get_
hook = CloudSQLDatabaseHook(
gcp_cloudsql_conn_id='cloudsql_connection', default_gcp_project_id='google_connection'
)
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
hook.create_connection()
- err = cm.exception
- self.assertIn("needs to be set in connection", str(err))
+ err = ctx.value
+ assert "needs to be set in connection" in str(err)
@mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection')
def test_cloudsql_database_hook_get_sqlproxy_runner_no_proxy(self, get_connection):
@@ -964,10 +965,10 @@ def test_cloudsql_database_hook_get_sqlproxy_runner_no_proxy(self, get_connectio
hook = CloudSQLDatabaseHook(
gcp_cloudsql_conn_id='cloudsql_connection', default_gcp_project_id='google_connection'
)
- with self.assertRaises(ValueError) as cm:
+ with pytest.raises(ValueError) as ctx:
hook.get_sqlproxy_runner()
- err = cm.exception
- self.assertIn('Proxy runner can only be retrieved in case of use_proxy = True', str(err))
+ err = ctx.value
+ assert 'Proxy runner can only be retrieved in case of use_proxy = True' in str(err)
@mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection')
def test_cloudsql_database_hook_get_sqlproxy_runner(self, get_connection):
@@ -989,7 +990,7 @@ def test_cloudsql_database_hook_get_sqlproxy_runner(self, get_connection):
)
hook.create_connection()
proxy_runner = hook.get_sqlproxy_runner()
- self.assertIsNotNone(proxy_runner)
+ assert proxy_runner is not None
@mock.patch('airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection')
def test_cloudsql_database_hook_get_database_hook(self, get_connection):
@@ -1009,7 +1010,7 @@ def test_cloudsql_database_hook_get_database_hook(self, get_connection):
)
connection = hook.create_connection()
db_hook = hook.get_database_hook(connection=connection)
- self.assertIsNotNone(db_hook)
+ assert db_hook is not None
class TestCloudSqlDatabaseQueryHook(unittest.TestCase):
@@ -1056,14 +1057,14 @@ def setUp(self, m):
def test_get_sqlproxy_runner(self):
self.db_hook._generate_connection_uri()
sqlproxy_runner = self.db_hook.get_sqlproxy_runner()
- self.assertEqual(sqlproxy_runner.gcp_conn_id, self.connection.conn_id)
+ assert sqlproxy_runner.gcp_conn_id == self.connection.conn_id
project = self.sql_connection.extra_dejson['project_id']
location = self.sql_connection.extra_dejson['location']
instance = self.sql_connection.extra_dejson['instance']
instance_spec = "{project}:{location}:{instance}".format(
project=project, location=location, instance=instance
)
- self.assertEqual(sqlproxy_runner.instance_specification, instance_spec)
+ assert sqlproxy_runner.instance_specification == instance_spec
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
def test_hook_with_not_too_long_unix_socket_path(self, get_connection):
@@ -1077,17 +1078,17 @@ def test_hook_with_not_too_long_unix_socket_path(self, get_connection):
get_connection.side_effect = [Connection(uri=uri)]
hook = CloudSQLDatabaseHook()
connection = hook.create_connection()
- self.assertEqual('postgres', connection.conn_type)
- self.assertEqual('testdb', connection.schema)
+ assert 'postgres' == connection.conn_type
+ assert 'testdb' == connection.schema
def _verify_postgres_connection(self, get_connection, uri):
get_connection.side_effect = [Connection(uri=uri)]
hook = CloudSQLDatabaseHook()
connection = hook.create_connection()
- self.assertEqual('postgres', connection.conn_type)
- self.assertEqual('127.0.0.1', connection.host)
- self.assertEqual(3200, connection.port)
- self.assertEqual('testdb', connection.schema)
+ assert 'postgres' == connection.conn_type
+ assert '127.0.0.1' == connection.host
+ assert 3200 == connection.port
+ assert 'testdb' == connection.schema
return connection
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
@@ -1108,9 +1109,9 @@ def test_hook_with_correct_parameters_postgres_ssl(self, get_connection):
"sslkey=/bin/bash&sslrootcert=/bin/bash"
)
connection = self._verify_postgres_connection(get_connection, uri)
- self.assertEqual('/bin/bash', connection.extra_dejson['sslkey'])
- self.assertEqual('/bin/bash', connection.extra_dejson['sslcert'])
- self.assertEqual('/bin/bash', connection.extra_dejson['sslrootcert'])
+ assert '/bin/bash' == connection.extra_dejson['sslkey']
+ assert '/bin/bash' == connection.extra_dejson['sslcert']
+ assert '/bin/bash' == connection.extra_dejson['sslrootcert']
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
def test_hook_with_correct_parameters_postgres_proxy_socket(self, get_connection):
@@ -1122,11 +1123,11 @@ def test_hook_with_correct_parameters_postgres_proxy_socket(self, get_connection
get_connection.side_effect = [Connection(uri=uri)]
hook = CloudSQLDatabaseHook()
connection = hook.create_connection()
- self.assertEqual('postgres', connection.conn_type)
- self.assertIn('/tmp', connection.host)
- self.assertIn('example-project:europe-west1:testdb', connection.host)
- self.assertIsNone(connection.port)
- self.assertEqual('testdb', connection.schema)
+ assert 'postgres' == connection.conn_type
+ assert '/tmp' in connection.host
+ assert 'example-project:europe-west1:testdb' in connection.host
+ assert connection.port is None
+ assert 'testdb' == connection.schema
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
def test_hook_with_correct_parameters_project_id_missing(self, get_connection):
@@ -1141,10 +1142,10 @@ def verify_mysql_connection(self, get_connection, uri):
get_connection.side_effect = [Connection(uri=uri)]
hook = CloudSQLDatabaseHook()
connection = hook.create_connection()
- self.assertEqual('mysql', connection.conn_type)
- self.assertEqual('127.0.0.1', connection.host)
- self.assertEqual(3200, connection.port)
- self.assertEqual('testdb', connection.schema)
+ assert 'mysql' == connection.conn_type
+ assert '127.0.0.1' == connection.host
+ assert 3200 == connection.port
+ assert 'testdb' == connection.schema
return connection
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
@@ -1157,10 +1158,10 @@ def test_hook_with_correct_parameters_postgres_proxy_tcp(self, get_connection):
get_connection.side_effect = [Connection(uri=uri)]
hook = CloudSQLDatabaseHook()
connection = hook.create_connection()
- self.assertEqual('postgres', connection.conn_type)
- self.assertEqual('127.0.0.1', connection.host)
- self.assertNotEqual(3200, connection.port)
- self.assertEqual('testdb', connection.schema)
+ assert 'postgres' == connection.conn_type
+ assert '127.0.0.1' == connection.host
+ assert 3200 != connection.port
+ assert 'testdb' == connection.schema
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
def test_hook_with_correct_parameters_mysql(self, get_connection):
@@ -1180,9 +1181,9 @@ def test_hook_with_correct_parameters_mysql_ssl(self, get_connection):
"sslkey=/bin/bash&sslrootcert=/bin/bash"
)
connection = self.verify_mysql_connection(get_connection, uri)
- self.assertEqual('/bin/bash', json.loads(connection.extra_dejson['ssl'])['cert'])
- self.assertEqual('/bin/bash', json.loads(connection.extra_dejson['ssl'])['key'])
- self.assertEqual('/bin/bash', json.loads(connection.extra_dejson['ssl'])['ca'])
+ assert '/bin/bash' == json.loads(connection.extra_dejson['ssl'])['cert']
+ assert '/bin/bash' == json.loads(connection.extra_dejson['ssl'])['key']
+ assert '/bin/bash' == json.loads(connection.extra_dejson['ssl'])['ca']
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
def test_hook_with_correct_parameters_mysql_proxy_socket(self, get_connection):
@@ -1194,12 +1195,12 @@ def test_hook_with_correct_parameters_mysql_proxy_socket(self, get_connection):
get_connection.side_effect = [Connection(uri=uri)]
hook = CloudSQLDatabaseHook()
connection = hook.create_connection()
- self.assertEqual('mysql', connection.conn_type)
- self.assertEqual('localhost', connection.host)
- self.assertIn('/tmp', connection.extra_dejson['unix_socket'])
- self.assertIn('example-project:europe-west1:testdb', connection.extra_dejson['unix_socket'])
- self.assertIsNone(connection.port)
- self.assertEqual('testdb', connection.schema)
+ assert 'mysql' == connection.conn_type
+ assert 'localhost' == connection.host
+ assert '/tmp' in connection.extra_dejson['unix_socket']
+ assert 'example-project:europe-west1:testdb' in connection.extra_dejson['unix_socket']
+ assert connection.port is None
+ assert 'testdb' == connection.schema
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
def test_hook_with_correct_parameters_mysql_tcp(self, get_connection):
@@ -1211,7 +1212,7 @@ def test_hook_with_correct_parameters_mysql_tcp(self, get_connection):
get_connection.side_effect = [Connection(uri=uri)]
hook = CloudSQLDatabaseHook()
connection = hook.create_connection()
- self.assertEqual('mysql', connection.conn_type)
- self.assertEqual('127.0.0.1', connection.host)
- self.assertNotEqual(3200, connection.port)
- self.assertEqual('testdb', connection.schema)
+ assert 'mysql' == connection.conn_type
+ assert '127.0.0.1' == connection.host
+ assert 3200 != connection.port
+ assert 'testdb' == connection.schema
diff --git a/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service.py b/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service.py
index b7d869b535cb4..3b423177ee26e 100644
--- a/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service.py
+++ b/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service.py
@@ -23,6 +23,7 @@
from unittest import mock
from unittest.mock import MagicMock, PropertyMock
+import pytest
from googleapiclient.errors import HttpError
from parameterized import parameterized
@@ -142,7 +143,7 @@ def test_pass_name_on_create_job(
body = _with_name(TEST_BODY, TEST_CLEAR_JOB_NAME)
get_conn.side_effect = HttpError(GCPRequestMock(), TEST_HTTP_ERR_CONTENT)
- with self.assertRaises(HttpError):
+ with pytest.raises(HttpError):
# check status DELETED generates new job name
get_transfer_job.return_value = TEST_RESULT_STATUS_DELETED
@@ -153,7 +154,7 @@ def test_pass_name_on_create_job(
enable_transfer_job.return_value = TEST_RESULT_STATUS_ENABLED
res = self.gct_hook.create_transfer_job(body=body)
- self.assertEqual(res, TEST_RESULT_STATUS_ENABLED)
+ assert res == TEST_RESULT_STATUS_ENABLED
class TestJobNames(unittest.TestCase):
@@ -162,7 +163,7 @@ def setUp(self) -> None:
def test_new_suffix(self):
for job_name in ["jobNames/new_job", "jobNames/new_job_h", "jobNames/newJob"]:
- self.assertIsNotNone(self.re_suffix.match(gen_job_name(job_name).split("_")[-1]))
+ assert self.re_suffix.match(gen_job_name(job_name).split("_")[-1]) is not None
class TestGCPTransferServiceHookWithPassedProjectId(unittest.TestCase):
@@ -183,8 +184,8 @@ def test_gct_client_creation(self, mock_build, mock_authorize):
mock_build.assert_called_once_with(
'storagetransfer', 'v1', http=mock_authorize.return_value, cache_discovery=False
)
- self.assertEqual(mock_build.return_value, result)
- self.assertEqual(self.gct_hook._conn, result)
+ assert mock_build.return_value == result
+ assert self.gct_hook._conn == result
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -200,7 +201,7 @@ def test_create_transfer_job(self, get_conn, mock_project_id):
execute_method = create_method.return_value.execute
execute_method.return_value = TEST_TRANSFER_JOB
res = self.gct_hook.create_transfer_job(body=TEST_BODY)
- self.assertEqual(res, TEST_TRANSFER_JOB)
+ assert res == TEST_TRANSFER_JOB
create_method.assert_called_once_with(body=TEST_BODY)
execute_method.assert_called_once_with(num_retries=5)
@@ -213,8 +214,8 @@ def test_get_transfer_job(self, get_conn):
execute_method = get_method.return_value.execute
execute_method.return_value = TEST_TRANSFER_JOB
res = self.gct_hook.get_transfer_job(job_name=TEST_TRANSFER_JOB_NAME, project_id=TEST_PROJECT_ID)
- self.assertIsNotNone(res)
- self.assertEqual(TEST_TRANSFER_JOB_NAME, res[NAME])
+ assert res is not None
+ assert TEST_TRANSFER_JOB_NAME == res[NAME]
get_method.assert_called_once_with(jobName=TEST_TRANSFER_JOB_NAME, projectId=TEST_PROJECT_ID)
execute_method.assert_called_once_with(num_retries=5)
@@ -236,14 +237,14 @@ def test_list_transfer_job(self, get_conn, mock_project_id):
list_next.return_value = None
res = self.gct_hook.list_transfer_job(request_filter=TEST_TRANSFER_JOB_FILTER)
- self.assertIsNotNone(res)
- self.assertEqual(res, [TEST_TRANSFER_JOB])
+ assert res is not None
+ assert res == [TEST_TRANSFER_JOB]
list_method.assert_called_once_with(filter=mock.ANY)
args, kwargs = list_method.call_args_list[0]
- self.assertEqual(
- json.loads(kwargs['filter']),
- {FILTER_PROJECT_ID: TEST_PROJECT_ID, FILTER_JOB_NAMES: [TEST_TRANSFER_JOB_NAME]},
- )
+ assert json.loads(kwargs['filter']) == {
+ FILTER_PROJECT_ID: TEST_PROJECT_ID,
+ FILTER_JOB_NAMES: [TEST_TRANSFER_JOB_NAME],
+ }
list_execute_method.assert_called_once_with(num_retries=5)
@mock.patch(
@@ -262,7 +263,7 @@ def test_update_transfer_job(self, get_conn, mock_project_id):
res = self.gct_hook.update_transfer_job(
job_name=TEST_TRANSFER_JOB_NAME, body=TEST_UPDATE_TRANSFER_JOB_BODY
)
- self.assertIsNotNone(res)
+ assert res is not None
update_method.assert_called_once_with(
jobName=TEST_TRANSFER_JOB_NAME, body=TEST_UPDATE_TRANSFER_JOB_BODY
)
@@ -309,7 +310,7 @@ def test_get_transfer_operation(self, get_conn):
execute_method = get_method.return_value.execute
execute_method.return_value = TEST_TRANSFER_OPERATION
res = self.gct_hook.get_transfer_operation(operation_name=TEST_TRANSFER_OPERATION_NAME)
- self.assertEqual(res, TEST_TRANSFER_OPERATION)
+ assert res == TEST_TRANSFER_OPERATION
get_method.assert_called_once_with(name=TEST_TRANSFER_OPERATION_NAME)
execute_method.assert_called_once_with(num_retries=5)
@@ -331,14 +332,14 @@ def test_list_transfer_operation(self, get_conn, mock_project_id):
list_next.return_value = None
res = self.gct_hook.list_transfer_operations(request_filter=TEST_TRANSFER_OPERATION_FILTER)
- self.assertIsNotNone(res)
- self.assertEqual(res, [TEST_TRANSFER_OPERATION])
+ assert res is not None
+ assert res == [TEST_TRANSFER_OPERATION]
list_method.assert_called_once_with(filter=mock.ANY, name='transferOperations')
args, kwargs = list_method.call_args_list[0]
- self.assertEqual(
- json.loads(kwargs['filter']),
- {FILTER_PROJECT_ID: TEST_PROJECT_ID, FILTER_JOB_NAMES: [TEST_TRANSFER_JOB_NAME]},
- )
+ assert json.loads(kwargs['filter']) == {
+ FILTER_PROJECT_ID: TEST_PROJECT_ID,
+ FILTER_JOB_NAMES: [TEST_TRANSFER_JOB_NAME],
+ }
list_execute_method.assert_called_once_with(num_retries=5)
@mock.patch(
@@ -435,9 +436,9 @@ def test_wait_for_transfer_job_failed(self, mock_get_conn, mock_sleep, mock_proj
mock_get_conn.return_value.transferOperations.return_value.list_next.return_value = None
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.gct_hook.wait_for_transfer_job({PROJECT_ID: TEST_PROJECT_ID, NAME: 'transferJobs/test-job'})
- self.assertTrue(list_method.called)
+ assert list_method.called
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -467,8 +468,8 @@ def test_wait_for_transfer_job_expect_failed(
}
get_conn.return_value.transferOperations.return_value.list_next.return_value = None
- with self.assertRaisesRegex(
- AirflowException, "An unexpected operation status was encountered. Expected: SUCCESS"
+ with pytest.raises(
+ AirflowException, match="An unexpected operation status was encountered. Expected: SUCCESS"
):
self.gct_hook.wait_for_transfer_job(
job={PROJECT_ID: 'test-project', NAME: 'transferJobs/test-job'},
@@ -500,9 +501,9 @@ def test_wait_for_transfer_job_expect_failed(
def test_operations_contain_expected_statuses_red_path(self, statuses, expected_statuses):
operations = [{NAME: TEST_TRANSFER_OPERATION_NAME, METADATA: {STATUS: status}} for status in statuses]
- with self.assertRaisesRegex(
+ with pytest.raises(
AirflowException,
- "An unexpected operation status was encountered. Expected: {}".format(
+ match="An unexpected operation status was encountered. Expected: {}".format(
", ".join(expected_statuses)
),
):
@@ -542,7 +543,7 @@ def test_operations_contain_expected_statuses_green_path(self, statuses, expecte
operations, expected_statuses
)
- self.assertTrue(result)
+ assert result
class TestGCPTransferServiceHookWithProjectIdFromConnection(unittest.TestCase):
@@ -563,8 +564,8 @@ def test_gct_client_creation(self, mock_build, mock_authorize):
mock_build.assert_called_once_with(
'storagetransfer', 'v1', http=mock_authorize.return_value, cache_discovery=False
)
- self.assertEqual(mock_build.return_value, result)
- self.assertEqual(self.gct_hook._conn, result)
+ assert mock_build.return_value == result
+ assert self.gct_hook._conn == result
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -580,7 +581,7 @@ def test_create_transfer_job(self, get_conn, mock_project_id):
execute_method = create_method.return_value.execute
execute_method.return_value = deepcopy(TEST_TRANSFER_JOB)
res = self.gct_hook.create_transfer_job(body=self._without_project_id(TEST_BODY))
- self.assertEqual(res, TEST_TRANSFER_JOB)
+ assert res == TEST_TRANSFER_JOB
create_method.assert_called_once_with(body=self._with_project_id(TEST_BODY, 'example-project'))
execute_method.assert_called_once_with(num_retries=5)
@@ -600,8 +601,8 @@ def test_get_transfer_job(self, get_conn, mock_project_id):
res = self.gct_hook.get_transfer_job( # pylint: disable=no-value-for-parameter
job_name=TEST_TRANSFER_JOB_NAME
)
- self.assertIsNotNone(res)
- self.assertEqual(TEST_TRANSFER_JOB_NAME, res[NAME])
+ assert res is not None
+ assert TEST_TRANSFER_JOB_NAME == res[NAME]
get_method.assert_called_once_with(jobName=TEST_TRANSFER_JOB_NAME, projectId='example-project')
execute_method.assert_called_once_with(num_retries=5)
@@ -625,15 +626,15 @@ def test_list_transfer_job(self, get_conn, mock_project_id):
res = self.gct_hook.list_transfer_job(
request_filter=_without_key(TEST_TRANSFER_JOB_FILTER, FILTER_PROJECT_ID)
)
- self.assertIsNotNone(res)
- self.assertEqual(res, [TEST_TRANSFER_JOB])
+ assert res is not None
+ assert res == [TEST_TRANSFER_JOB]
list_method.assert_called_once_with(filter=mock.ANY)
args, kwargs = list_method.call_args_list[0]
- self.assertEqual(
- json.loads(kwargs['filter']),
- {FILTER_PROJECT_ID: 'example-project', FILTER_JOB_NAMES: [TEST_TRANSFER_JOB_NAME]},
- )
+ assert json.loads(kwargs['filter']) == {
+ FILTER_PROJECT_ID: 'example-project',
+ FILTER_JOB_NAMES: [TEST_TRANSFER_JOB_NAME],
+ }
list_execute_method.assert_called_once_with(num_retries=5)
@mock.patch(
@@ -652,7 +653,7 @@ def test_update_transfer_job(self, get_conn, mock_project_id):
res = self.gct_hook.update_transfer_job(
job_name=TEST_TRANSFER_JOB_NAME, body=self._without_project_id(TEST_UPDATE_TRANSFER_JOB_BODY)
)
- self.assertIsNotNone(res)
+ assert res is not None
update_method.assert_called_once_with(
jobName=TEST_TRANSFER_JOB_NAME,
body=self._with_project_id(TEST_UPDATE_TRANSFER_JOB_BODY, 'example-project'),
@@ -700,7 +701,7 @@ def test_get_transfer_operation(self, get_conn):
execute_method = get_method.return_value.execute
execute_method.return_value = TEST_TRANSFER_OPERATION
res = self.gct_hook.get_transfer_operation(operation_name=TEST_TRANSFER_OPERATION_NAME)
- self.assertEqual(res, TEST_TRANSFER_OPERATION)
+ assert res == TEST_TRANSFER_OPERATION
get_method.assert_called_once_with(name=TEST_TRANSFER_OPERATION_NAME)
execute_method.assert_called_once_with(num_retries=5)
@@ -724,14 +725,14 @@ def test_list_transfer_operation(self, get_conn, mock_project_id):
res = self.gct_hook.list_transfer_operations(
request_filter=_without_key(TEST_TRANSFER_OPERATION_FILTER, FILTER_PROJECT_ID)
)
- self.assertIsNotNone(res)
- self.assertEqual(res, [TEST_TRANSFER_OPERATION])
+ assert res is not None
+ assert res == [TEST_TRANSFER_OPERATION]
list_method.assert_called_once_with(filter=mock.ANY, name='transferOperations')
args, kwargs = list_method.call_args_list[0]
- self.assertEqual(
- json.loads(kwargs['filter']),
- {FILTER_PROJECT_ID: 'example-project', FILTER_JOB_NAMES: [TEST_TRANSFER_JOB_NAME]},
- )
+ assert json.loads(kwargs['filter']) == {
+ FILTER_PROJECT_ID: 'example-project',
+ FILTER_JOB_NAMES: [TEST_TRANSFER_JOB_NAME],
+ }
list_execute_method.assert_called_once_with(num_retries=5)
@staticmethod
@@ -766,8 +767,8 @@ def test_gct_client_creation(self, mock_build, mock_authorize):
mock_build.assert_called_once_with(
'storagetransfer', 'v1', http=mock_authorize.return_value, cache_discovery=False
)
- self.assertEqual(mock_build.return_value, result)
- self.assertEqual(self.gct_hook._conn, result)
+ assert mock_build.return_value == result
+ assert self.gct_hook._conn == result
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -782,14 +783,13 @@ def test_create_transfer_job(self, get_conn, mock_project_id):
create_method = get_conn.return_value.transferJobs.return_value.create
execute_method = create_method.return_value.execute
execute_method.return_value = deepcopy(TEST_TRANSFER_JOB)
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
self.gct_hook.create_transfer_job(body=_without_key(TEST_BODY, PROJECT_ID))
- self.assertEqual(
+ assert (
'The project id must be passed either as `projectId` key in `body` '
'parameter or as project_id '
- 'extra in Google Cloud connection definition. Both are not set!',
- str(e.exception),
+ 'extra in Google Cloud connection definition. Both are not set!' == str(ctx.value)
)
@mock.patch(
@@ -805,15 +805,14 @@ def test_get_transfer_job(self, get_conn, mock_project_id):
get_method = get_conn.return_value.transferJobs.return_value.get
execute_method = get_method.return_value.execute
execute_method.return_value = TEST_TRANSFER_JOB
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
self.gct_hook.get_transfer_job( # pylint: disable=no-value-for-parameter
job_name=TEST_TRANSFER_JOB_NAME
)
- self.assertEqual(
+ assert (
'The project id must be passed either as keyword project_id '
'parameter or as project_id extra in Google Cloud connection definition. '
- 'Both are not set!',
- str(e.exception),
+ 'Both are not set!' == str(ctx.value)
)
@mock.patch(
@@ -833,15 +832,14 @@ def test_list_transfer_job(self, get_conn, mock_project_id):
list_next = get_conn.return_value.transferJobs.return_value.list_next
list_next.return_value = None
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
self.gct_hook.list_transfer_job(
request_filter=_without_key(TEST_TRANSFER_JOB_FILTER, FILTER_PROJECT_ID)
)
- self.assertEqual(
+ assert (
'The project id must be passed either as `project_id` key in `filter` parameter or as '
- 'project_id extra in Google Cloud connection definition. Both are not set!',
- str(e.exception),
+ 'project_id extra in Google Cloud connection definition. Both are not set!' == str(ctx.value)
)
@mock.patch(
@@ -864,7 +862,7 @@ def test_list_transfer_operation_multiple_page(self, get_conn, mock_project_id):
get_conn.return_value.transferOperations.return_value = transfer_operation_mock
res = self.gct_hook.list_transfer_operations(request_filter=TEST_TRANSFER_OPERATION_FILTER)
- self.assertEqual(res, [TEST_TRANSFER_OPERATION] * 4)
+ assert res == [TEST_TRANSFER_OPERATION] * 4
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -879,15 +877,14 @@ def test_update_transfer_job(self, get_conn, mock_project_id):
update_method = get_conn.return_value.transferJobs.return_value.patch
execute_method = update_method.return_value.execute
execute_method.return_value = TEST_TRANSFER_JOB
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
self.gct_hook.update_transfer_job(
job_name=TEST_TRANSFER_JOB_NAME, body=_without_key(TEST_UPDATE_TRANSFER_JOB_BODY, PROJECT_ID)
)
- self.assertEqual(
+ assert (
'The project id must be passed either as `projectId` key in `body` parameter or as project_id '
- 'extra in Google Cloud connection definition. Both are not set!',
- str(e.exception),
+ 'extra in Google Cloud connection definition. Both are not set!' == str(ctx.value)
)
@mock.patch(
@@ -900,15 +897,14 @@ def test_update_transfer_job(self, get_conn, mock_project_id):
'.CloudDataTransferServiceHook.get_conn'
)
def test_delete_transfer_job(self, get_conn, mock_project_id): # pylint: disable=unused-argument
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
self.gct_hook.delete_transfer_job( # pylint: disable=no-value-for-parameter
job_name=TEST_TRANSFER_JOB_NAME
)
- self.assertEqual(
+ assert (
'The project id must be passed either as keyword project_id parameter or as project_id extra in '
- 'Google Cloud connection definition. Both are not set!',
- str(e.exception),
+ 'Google Cloud connection definition. Both are not set!' == str(ctx.value)
)
@mock.patch(
@@ -928,13 +924,12 @@ def test_list_transfer_operation(self, get_conn, mock_project_id):
list_next = get_conn.return_value.transferOperations.return_value.list_next
list_next.return_value = None
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
self.gct_hook.list_transfer_operations(
request_filter=_without_key(TEST_TRANSFER_OPERATION_FILTER, FILTER_PROJECT_ID)
)
- self.assertEqual(
+ assert (
'The project id must be passed either as `project_id` key in `filter` parameter or as project_id '
- 'extra in Google Cloud connection definition. Both are not set!',
- str(e.exception),
+ 'extra in Google Cloud connection definition. Both are not set!' == str(ctx.value)
)
diff --git a/tests/providers/google/cloud/hooks/test_compute.py b/tests/providers/google/cloud/hooks/test_compute.py
index 64f6d928fcf60..94692b1ed2ab7 100644
--- a/tests/providers/google/cloud/hooks/test_compute.py
+++ b/tests/providers/google/cloud/hooks/test_compute.py
@@ -22,6 +22,8 @@
from unittest import mock
from unittest.mock import PropertyMock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.compute import ComputeEngineHook, GceOperationStatus
from tests.providers.google.cloud.utils.base_gcp_mock import (
@@ -52,7 +54,7 @@ def test_gce_client_creation(self, mock_build, mock_authorize):
mock_build.assert_called_once_with(
'compute', 'v1', http=mock_authorize.return_value, cache_discovery=False
)
- self.assertEqual(mock_build.return_value, result)
+ assert mock_build.return_value == result
@mock.patch('airflow.providers.google.cloud.hooks.compute.ComputeEngineHook.get_conn')
@mock.patch(
@@ -66,7 +68,7 @@ def test_start_instance_overridden_project_id(self, wait_for_operation_to_comple
res = self.gce_hook_no_project_id.start_instance(
project_id='example-project', zone=GCE_ZONE, resource_id=GCE_INSTANCE
)
- self.assertIsNone(res)
+ assert res is None
start_method.assert_called_once_with(instance='instance', project='example-project', zone='zone')
execute_method.assert_called_once_with(num_retries=5)
wait_for_operation_to_complete.assert_called_once_with(
@@ -85,7 +87,7 @@ def test_stop_instance_overridden_project_id(self, wait_for_operation_to_complet
res = self.gce_hook_no_project_id.stop_instance(
project_id='example-project', zone=GCE_ZONE, resource_id=GCE_INSTANCE
)
- self.assertIsNone(res)
+ assert res is None
stop_method.assert_called_once_with(instance='instance', project='example-project', zone='zone')
execute_method.assert_called_once_with(num_retries=5)
wait_for_operation_to_complete.assert_called_once_with(
@@ -104,7 +106,7 @@ def test_set_machine_type_overridden_project_id(self, wait_for_operation_to_comp
res = self.gce_hook_no_project_id.set_machine_type(
body={}, project_id='example-project', zone=GCE_ZONE, resource_id=GCE_INSTANCE
)
- self.assertIsNone(res)
+ assert res is None
set_machine_type_method.assert_called_once_with(
body={}, instance='instance', project='example-project', zone='zone'
)
@@ -125,7 +127,7 @@ def test_get_instance_template_overridden_project_id(self, wait_for_operation_to
res = self.gce_hook_no_project_id.get_instance_template(
resource_id=GCE_INSTANCE_TEMPLATE, project_id='example-project'
)
- self.assertIsNotNone(res)
+ assert res is not None
get_method.assert_called_once_with(instanceTemplate='instance-template', project='example-project')
execute_method.assert_called_once_with(num_retries=5)
wait_for_operation_to_complete.assert_not_called()
@@ -142,7 +144,7 @@ def test_insert_instance_template_overridden_project_id(self, wait_for_operation
res = self.gce_hook_no_project_id.insert_instance_template(
project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, body={}, request_id=GCE_REQUEST_ID
)
- self.assertIsNone(res)
+ assert res is None
insert_method.assert_called_once_with(body={}, project='example-project', requestId='request_id')
execute_method.assert_called_once_with(num_retries=5)
wait_for_operation_to_complete.assert_called_once_with(
@@ -161,7 +163,7 @@ def test_get_instance_group_manager_overridden_project_id(self, wait_for_operati
res = self.gce_hook_no_project_id.get_instance_group_manager(
project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, zone=GCE_ZONE, resource_id=GCE_INSTANCE_GROUP_MANAGER
)
- self.assertIsNotNone(res)
+ assert res is not None
get_method.assert_called_once_with(
instanceGroupManager='instance_group_manager', project='example-project', zone='zone'
)
@@ -186,7 +188,7 @@ def test_patch_instance_group_manager_overridden_project_id(
body={},
request_id=GCE_REQUEST_ID,
)
- self.assertIsNone(res)
+ assert res is None
patch_method.assert_called_once_with(
body={},
instanceGroupManager='instance_group_manager',
@@ -227,7 +229,7 @@ def test_start_instance(self, wait_for_operation_to_complete, get_conn, mock_pro
resource_id=GCE_INSTANCE,
project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
)
- self.assertIsNone(res)
+ assert res is None
start_method.assert_called_once_with(instance='instance', project='example-project', zone='zone')
execute_method.assert_called_once_with(num_retries=5)
wait_for_operation_to_complete.assert_called_once_with(
@@ -244,7 +246,7 @@ def test_start_instance_overridden_project_id(self, wait_for_operation_to_comple
execute_method.return_value = {"name": "operation_id"}
wait_for_operation_to_complete.return_value = None
res = self.gce_hook.start_instance(project_id='new-project', zone=GCE_ZONE, resource_id=GCE_INSTANCE)
- self.assertIsNone(res)
+ assert res is None
start_method.assert_called_once_with(instance='instance', project='new-project', zone='zone')
execute_method.assert_called_once_with(num_retries=5)
wait_for_operation_to_complete.assert_called_once_with(
@@ -270,7 +272,7 @@ def test_stop_instance(self, wait_for_operation_to_complete, get_conn, mock_proj
resource_id=GCE_INSTANCE,
project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
)
- self.assertIsNone(res)
+ assert res is None
stop_method.assert_called_once_with(instance='instance', project='example-project', zone='zone')
execute_method.assert_called_once_with(num_retries=5)
wait_for_operation_to_complete.assert_called_once_with(
@@ -287,7 +289,7 @@ def test_stop_instance_overridden_project_id(self, wait_for_operation_to_complet
execute_method.return_value = {"name": "operation_id"}
wait_for_operation_to_complete.return_value = None
res = self.gce_hook.stop_instance(project_id='new-project', zone=GCE_ZONE, resource_id=GCE_INSTANCE)
- self.assertIsNone(res)
+ assert res is None
stop_method.assert_called_once_with(instance='instance', project='new-project', zone='zone')
execute_method.assert_called_once_with(num_retries=5)
wait_for_operation_to_complete.assert_called_once_with(
@@ -313,7 +315,7 @@ def test_set_machine_type_instance(self, wait_for_operation_to_complete, get_con
resource_id=GCE_INSTANCE,
project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
)
- self.assertIsNone(res)
+ assert res is None
execute_method.assert_called_once_with(num_retries=5)
wait_for_operation_to_complete.assert_called_once_with(
project_id='example-project', operation_name='operation_id', zone='zone'
@@ -330,7 +332,7 @@ def test_set_machine_type_instance_overridden_project_id(self, wait_for_operatio
res = self.gce_hook.set_machine_type(
project_id='new-project', body={}, zone=GCE_ZONE, resource_id=GCE_INSTANCE
)
- self.assertIsNone(res)
+ assert res is None
execute_method.assert_called_once_with(num_retries=5)
wait_for_operation_to_complete.assert_called_once_with(
project_id='new-project', operation_name='operation_id', zone='zone'
@@ -353,7 +355,7 @@ def test_get_instance_template(self, wait_for_operation_to_complete, get_conn, m
res = self.gce_hook.get_instance_template(
resource_id=GCE_INSTANCE_TEMPLATE, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST
)
- self.assertIsNotNone(res)
+ assert res is not None
get_method.assert_called_once_with(instanceTemplate='instance-template', project='example-project')
execute_method.assert_called_once_with(num_retries=5)
wait_for_operation_to_complete.assert_not_called()
@@ -368,7 +370,7 @@ def test_get_instance_template_overridden_project_id(self, wait_for_operation_to
execute_method.return_value = {"name": "operation_id"}
wait_for_operation_to_complete.return_value = None
res = self.gce_hook.get_instance_template(project_id='new-project', resource_id=GCE_INSTANCE_TEMPLATE)
- self.assertIsNotNone(res)
+ assert res is not None
get_method.assert_called_once_with(instanceTemplate='instance-template', project='new-project')
execute_method.assert_called_once_with(num_retries=5)
wait_for_operation_to_complete.assert_not_called()
@@ -392,7 +394,7 @@ def test_insert_instance_template(self, wait_for_operation_to_complete, get_conn
request_id=GCE_REQUEST_ID,
project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
)
- self.assertIsNone(res)
+ assert res is None
insert_method.assert_called_once_with(body={}, project='example-project', requestId='request_id')
execute_method.assert_called_once_with(num_retries=5)
wait_for_operation_to_complete.assert_called_once_with(
@@ -411,7 +413,7 @@ def test_insert_instance_template_overridden_project_id(self, wait_for_operation
res = self.gce_hook.insert_instance_template(
project_id='new-project', body={}, request_id=GCE_REQUEST_ID
)
- self.assertIsNone(res)
+ assert res is None
insert_method.assert_called_once_with(body={}, project='new-project', requestId='request_id')
execute_method.assert_called_once_with(num_retries=5)
wait_for_operation_to_complete.assert_called_once_with(
@@ -437,7 +439,7 @@ def test_get_instance_group_manager(self, wait_for_operation_to_complete, get_co
resource_id=GCE_INSTANCE_GROUP_MANAGER,
project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
)
- self.assertIsNotNone(res)
+ assert res is not None
get_method.assert_called_once_with(
instanceGroupManager='instance_group_manager', project='example-project', zone='zone'
)
@@ -456,7 +458,7 @@ def test_get_instance_group_manager_overridden_project_id(self, wait_for_operati
res = self.gce_hook.get_instance_group_manager(
project_id='new-project', zone=GCE_ZONE, resource_id=GCE_INSTANCE_GROUP_MANAGER
)
- self.assertIsNotNone(res)
+ assert res is not None
get_method.assert_called_once_with(
instanceGroupManager='instance_group_manager', project='new-project', zone='zone'
)
@@ -484,7 +486,7 @@ def test_patch_instance_group_manager(self, wait_for_operation_to_complete, get_
request_id=GCE_REQUEST_ID,
project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
)
- self.assertIsNone(res)
+ assert res is None
patch_method.assert_called_once_with(
body={},
instanceGroupManager='instance_group_manager',
@@ -515,7 +517,7 @@ def test_patch_instance_group_manager_overridden_project_id(
body={},
request_id=GCE_REQUEST_ID,
)
- self.assertIsNone(res)
+ assert res is None
patch_method.assert_called_once_with(
body={},
instanceGroupManager='instance_group_manager',
@@ -567,7 +569,7 @@ def test_wait_for_operation_to_complete_no_zone_error(self, mock_operation_statu
'httpErrorMessage': 'sample msg',
}
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.gce_hook._wait_for_operation_to_complete(
project_id=project_id, operation_name=operation_name, zone=None
)
diff --git a/tests/providers/google/cloud/hooks/test_compute_ssh.py b/tests/providers/google/cloud/hooks/test_compute_ssh.py
index 8de3038c5f8f1..785aece8a88c6 100644
--- a/tests/providers/google/cloud/hooks/test_compute_ssh.py
+++ b/tests/providers/google/cloud/hooks/test_compute_ssh.py
@@ -53,7 +53,7 @@ def test_get_conn_default_configuration(
hook = ComputeEngineSSHHook(instance_name=TEST_INSTANCE_NAME, zone=TEST_ZONE)
result = hook.get_conn()
- self.assertEqual(mock_ssh_client.return_value, result)
+ assert mock_ssh_client.return_value == result
mock_paramiko.RSAKey.generate.assert_called_once_with(2048)
mock_compute_hook.assert_has_calls(
@@ -112,7 +112,7 @@ def test_get_conn_authorize_using_instance_metadata(
hook = ComputeEngineSSHHook(instance_name=TEST_INSTANCE_NAME, zone=TEST_ZONE, use_oslogin=False)
result = hook.get_conn()
- self.assertEqual(mock_ssh_client.return_value, result)
+ assert mock_ssh_client.return_value == result
mock_paramiko.RSAKey.generate.assert_called_once_with(2048)
mock_compute_hook.assert_has_calls(
@@ -173,7 +173,7 @@ def test_get_conn_authorize_using_instance_metadata_append_ssh_keys(
hook = ComputeEngineSSHHook(instance_name=TEST_INSTANCE_NAME, zone=TEST_ZONE, use_oslogin=False)
result = hook.get_conn()
- self.assertEqual(mock_ssh_client.return_value, result)
+ assert mock_ssh_client.return_value == result
mock_compute_hook.return_value.set_instance_metadata.assert_called_once_with(
metadata={"items": [{"key": "ssh-keys", "value": f"{TEST_PUB_KEY}\n{TEST_PUB_KEY2}\n"}]},
@@ -201,7 +201,7 @@ def test_get_conn_private_ip(self, mock_ssh_client, mock_paramiko, mock_os_login
instance_name=TEST_INSTANCE_NAME, zone=TEST_ZONE, use_oslogin=False, use_internal_ip=True
)
result = hook.get_conn()
- self.assertEqual(mock_ssh_client.return_value, result)
+ assert mock_ssh_client.return_value == result
mock_compute_hook.return_value.get_instance_address.assert_called_once_with(
project_id=TEST_PROJECT_ID, resource_id=TEST_INSTANCE_NAME, use_internal_ip=True, zone=TEST_ZONE
@@ -227,7 +227,7 @@ def test_get_conn_custom_hostname(
hostname="custom-hostname",
)
result = hook.get_conn()
- self.assertEqual(mock_ssh_client.return_value, result)
+ assert mock_ssh_client.return_value == result
mock_compute_hook.return_value.get_instance_address.assert_not_called()
mock_ssh_client.return_value.connect.assert_called_once_with(
@@ -252,7 +252,7 @@ def test_get_conn_iap_tunnel(self, mock_ssh_client, mock_paramiko, mock_os_login
instance_name=TEST_INSTANCE_NAME, zone=TEST_ZONE, use_oslogin=False, use_iap_tunnel=True
)
result = hook.get_conn()
- self.assertEqual(mock_ssh_client.return_value, result)
+ assert mock_ssh_client.return_value == result
mock_ssh_client.return_value.connect.assert_called_once_with(
hostname=mock.ANY,
@@ -286,7 +286,7 @@ class CustomException(Exception):
hook = ComputeEngineSSHHook(instance_name=TEST_INSTANCE_NAME, zone=TEST_ZONE)
hook.get_conn()
- self.assertEqual(3, mock_ssh_client.return_value.connect.call_count)
+ assert 3 == mock_ssh_client.return_value.connect.call_count
def test_read_configuration_from_connection(self):
conn = Connection(
@@ -308,17 +308,17 @@ def test_read_configuration_from_connection(self):
with mock.patch.dict("os.environ", AIRFLOW_CONN_GCPSSH=conn_uri):
hook = ComputeEngineSSHHook(gcp_conn_id="gcpssh")
hook._load_connection_config()
- self.assertEqual("conn-instance-name", hook.instance_name)
- self.assertEqual("conn-host", hook.hostname)
- self.assertEqual("conn-user", hook.user)
- self.assertEqual(True, hook.use_internal_ip)
- self.assertIsInstance(hook.use_internal_ip, bool)
- self.assertEqual(True, hook.use_iap_tunnel)
- self.assertIsInstance(hook.use_iap_tunnel, bool)
- self.assertEqual(False, hook.use_oslogin)
- self.assertIsInstance(hook.use_oslogin, bool)
- self.assertEqual(4242, hook.expire_time)
- self.assertIsInstance(hook.expire_time, int)
+ assert "conn-instance-name" == hook.instance_name
+ assert "conn-host" == hook.hostname
+ assert "conn-user" == hook.user
+ assert hook.use_internal_ip is True
+ assert isinstance(hook.use_internal_ip, bool)
+ assert hook.use_iap_tunnel is True
+ assert isinstance(hook.use_iap_tunnel, bool)
+ assert hook.use_oslogin is False
+ assert isinstance(hook.use_oslogin, bool)
+ assert 4242 == hook.expire_time
+ assert isinstance(hook.expire_time, int)
def test_read_configuration_from_connection_empty_config(self):
conn = Connection(
@@ -329,14 +329,14 @@ def test_read_configuration_from_connection_empty_config(self):
with mock.patch.dict("os.environ", AIRFLOW_CONN_GCPSSH=conn_uri):
hook = ComputeEngineSSHHook(gcp_conn_id="gcpssh")
hook._load_connection_config()
- self.assertEqual(None, hook.instance_name)
- self.assertEqual(None, hook.hostname)
- self.assertEqual("root", hook.user)
- self.assertEqual(False, hook.use_internal_ip)
- self.assertIsInstance(hook.use_internal_ip, bool)
- self.assertEqual(False, hook.use_iap_tunnel)
- self.assertIsInstance(hook.use_iap_tunnel, bool)
- self.assertEqual(False, hook.use_oslogin)
- self.assertIsInstance(hook.use_oslogin, bool)
- self.assertEqual(300, hook.expire_time)
- self.assertIsInstance(hook.expire_time, int)
+ assert None is hook.instance_name
+ assert None is hook.hostname
+ assert "root" == hook.user
+ assert False is hook.use_internal_ip
+ assert isinstance(hook.use_internal_ip, bool)
+ assert False is hook.use_iap_tunnel
+ assert isinstance(hook.use_iap_tunnel, bool)
+ assert False is hook.use_oslogin
+ assert isinstance(hook.use_oslogin, bool)
+ assert 300 == hook.expire_time
+ assert isinstance(hook.expire_time, int)
diff --git a/tests/providers/google/cloud/hooks/test_datacatalog.py b/tests/providers/google/cloud/hooks/test_datacatalog.py
index 02f0002e2a07a..99d785fa61637 100644
--- a/tests/providers/google/cloud/hooks/test_datacatalog.py
+++ b/tests/providers/google/cloud/hooks/test_datacatalog.py
@@ -20,6 +20,7 @@
from typing import Dict, Sequence, Tuple
from unittest import TestCase, mock
+import pytest
from google.api_core.retry import Retry
from google.cloud.datacatalog_v1beta1 import CreateTagRequest, CreateTagTemplateRequest
from google.cloud.datacatalog_v1beta1.types import Entry, Tag, TagTemplate
@@ -130,8 +131,8 @@ def test_lookup_entry_with_sql_resource(self, mock_get_conn, mock_get_creds_and_
)
@mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn")
def test_lookup_entry_without_resource(self, mock_get_conn, mock_get_creds_and_project_id) -> None:
- with self.assertRaisesRegex(
- AirflowException, re.escape("At least one of linked_resource, sql_resource should be set.")
+ with pytest.raises(
+ AirflowException, match=re.escape("At least one of linked_resource, sql_resource should be set.")
):
self.hook.lookup_entry(retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA)
@@ -561,7 +562,7 @@ def test_get_tag_for_template_name(self, mock_get_conn, mock_get_creds_and_proje
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
)
- self.assertEqual(result, tag_2)
+ assert result == tag_2
@mock.patch(
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id",
@@ -1089,7 +1090,7 @@ def test_get_tag_for_template_name(self, mock_get_conn, mock_get_creds_and_proje
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
)
- self.assertEqual(result, tag_2)
+ assert result == tag_2
@mock.patch(
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id",
@@ -1245,7 +1246,7 @@ def setUp(
)
@mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn")
def test_create_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> None:
- with self.assertRaisesRegex(AirflowException, TEST_MESSAGE):
+ with pytest.raises(AirflowException, match=TEST_MESSAGE):
self.hook.create_entry( # pylint: disable=no-value-for-parameter
location=TEST_LOCATION,
entry_group=TEST_ENTRY_GROUP_ID,
@@ -1262,7 +1263,7 @@ def test_create_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> Non
)
@mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn")
def test_create_entry_group(self, mock_get_conn, mock_get_creds_and_project_id) -> None:
- with self.assertRaisesRegex(AirflowException, TEST_MESSAGE):
+ with pytest.raises(AirflowException, match=TEST_MESSAGE):
self.hook.create_entry_group( # pylint: disable=no-value-for-parameter
location=TEST_LOCATION,
entry_group_id=TEST_ENTRY_GROUP_ID,
@@ -1278,7 +1279,7 @@ def test_create_entry_group(self, mock_get_conn, mock_get_creds_and_project_id)
)
@mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn")
def test_create_tag(self, mock_get_conn, mock_get_creds_and_project_id) -> None:
- with self.assertRaisesRegex(AirflowException, TEST_MESSAGE):
+ with pytest.raises(AirflowException, match=TEST_MESSAGE):
self.hook.create_tag( # pylint: disable=no-value-for-parameter
location=TEST_LOCATION,
@@ -1297,7 +1298,7 @@ def test_create_tag(self, mock_get_conn, mock_get_creds_and_project_id) -> None:
)
@mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn")
def test_create_tag_protobuff(self, mock_get_conn, mock_get_creds_and_project_id) -> None:
- with self.assertRaisesRegex(AirflowException, TEST_MESSAGE):
+ with pytest.raises(AirflowException, match=TEST_MESSAGE):
self.hook.create_tag( # pylint: disable=no-value-for-parameter
location=TEST_LOCATION,
@@ -1316,7 +1317,7 @@ def test_create_tag_protobuff(self, mock_get_conn, mock_get_creds_and_project_id
)
@mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn")
def test_create_tag_template(self, mock_get_conn, mock_get_creds_and_project_id) -> None:
- with self.assertRaisesRegex(AirflowException, TEST_MESSAGE):
+ with pytest.raises(AirflowException, match=TEST_MESSAGE):
self.hook.create_tag_template( # pylint: disable=no-value-for-parameter
location=TEST_LOCATION,
@@ -1333,7 +1334,7 @@ def test_create_tag_template(self, mock_get_conn, mock_get_creds_and_project_id)
)
@mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn")
def test_create_tag_template_field(self, mock_get_conn, mock_get_creds_and_project_id) -> None:
- with self.assertRaisesRegex(AirflowException, TEST_MESSAGE):
+ with pytest.raises(AirflowException, match=TEST_MESSAGE):
self.hook.create_tag_template_field( # pylint: disable=no-value-for-parameter
location=TEST_LOCATION,
@@ -1351,7 +1352,7 @@ def test_create_tag_template_field(self, mock_get_conn, mock_get_creds_and_proje
)
@mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn")
def test_delete_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> None:
- with self.assertRaisesRegex(AirflowException, TEST_MESSAGE):
+ with pytest.raises(AirflowException, match=TEST_MESSAGE):
self.hook.delete_entry( # pylint: disable=no-value-for-parameter
location=TEST_LOCATION,
@@ -1368,7 +1369,7 @@ def test_delete_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> Non
)
@mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn")
def test_delete_entry_group(self, mock_get_conn, mock_get_creds_and_project_id) -> None:
- with self.assertRaisesRegex(AirflowException, TEST_MESSAGE):
+ with pytest.raises(AirflowException, match=TEST_MESSAGE):
self.hook.delete_entry_group( # pylint: disable=no-value-for-parameter
location=TEST_LOCATION,
entry_group=TEST_ENTRY_GROUP_ID,
@@ -1383,7 +1384,7 @@ def test_delete_entry_group(self, mock_get_conn, mock_get_creds_and_project_id)
)
@mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn")
def test_delete_tag(self, mock_get_conn, mock_get_creds_and_project_id) -> None:
- with self.assertRaisesRegex(AirflowException, TEST_MESSAGE):
+ with pytest.raises(AirflowException, match=TEST_MESSAGE):
self.hook.delete_tag( # pylint: disable=no-value-for-parameter
location=TEST_LOCATION,
entry_group=TEST_ENTRY_GROUP_ID,
@@ -1400,7 +1401,7 @@ def test_delete_tag(self, mock_get_conn, mock_get_creds_and_project_id) -> None:
)
@mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn")
def test_delete_tag_template(self, mock_get_conn, mock_get_creds_and_project_id) -> None:
- with self.assertRaisesRegex(AirflowException, TEST_MESSAGE):
+ with pytest.raises(AirflowException, match=TEST_MESSAGE):
self.hook.delete_tag_template( # pylint: disable=no-value-for-parameter
location=TEST_LOCATION,
tag_template=TEST_TAG_TEMPLATE_ID,
@@ -1416,7 +1417,7 @@ def test_delete_tag_template(self, mock_get_conn, mock_get_creds_and_project_id)
)
@mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn")
def test_delete_tag_template_field(self, mock_get_conn, mock_get_creds_and_project_id) -> None:
- with self.assertRaisesRegex(AirflowException, TEST_MESSAGE):
+ with pytest.raises(AirflowException, match=TEST_MESSAGE):
self.hook.delete_tag_template_field( # pylint: disable=no-value-for-parameter
location=TEST_LOCATION,
tag_template=TEST_TAG_TEMPLATE_ID,
@@ -1433,7 +1434,7 @@ def test_delete_tag_template_field(self, mock_get_conn, mock_get_creds_and_proje
)
@mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn")
def test_get_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> None:
- with self.assertRaisesRegex(AirflowException, TEST_MESSAGE):
+ with pytest.raises(AirflowException, match=TEST_MESSAGE):
self.hook.get_entry( # pylint: disable=no-value-for-parameter
location=TEST_LOCATION,
entry_group=TEST_ENTRY_GROUP_ID,
@@ -1449,7 +1450,7 @@ def test_get_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> None:
)
@mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn")
def test_get_entry_group(self, mock_get_conn, mock_get_creds_and_project_id) -> None:
- with self.assertRaisesRegex(AirflowException, TEST_MESSAGE):
+ with pytest.raises(AirflowException, match=TEST_MESSAGE):
self.hook.get_entry_group( # pylint: disable=no-value-for-parameter
location=TEST_LOCATION,
entry_group=TEST_ENTRY_GROUP_ID,
@@ -1465,7 +1466,7 @@ def test_get_entry_group(self, mock_get_conn, mock_get_creds_and_project_id) ->
)
@mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn")
def test_get_tag_template(self, mock_get_conn, mock_get_creds_and_project_id) -> None:
- with self.assertRaisesRegex(AirflowException, TEST_MESSAGE):
+ with pytest.raises(AirflowException, match=TEST_MESSAGE):
self.hook.get_tag_template( # pylint: disable=no-value-for-parameter
location=TEST_LOCATION,
tag_template=TEST_TAG_TEMPLATE_ID,
@@ -1480,7 +1481,7 @@ def test_get_tag_template(self, mock_get_conn, mock_get_creds_and_project_id) ->
)
@mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn")
def test_list_tags(self, mock_get_conn, mock_get_creds_and_project_id) -> None:
- with self.assertRaisesRegex(AirflowException, TEST_MESSAGE):
+ with pytest.raises(AirflowException, match=TEST_MESSAGE):
self.hook.list_tags( # pylint: disable=no-value-for-parameter
location=TEST_LOCATION,
entry_group=TEST_ENTRY_GROUP_ID,
@@ -1501,7 +1502,7 @@ def test_get_tag_for_template_name(self, mock_get_conn, mock_get_creds_and_proje
tag_2 = mock.MagicMock(template=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2))
mock_get_conn.return_value.list_tags.return_value = [tag_1, tag_2]
- with self.assertRaisesRegex(AirflowException, TEST_MESSAGE):
+ with pytest.raises(AirflowException, match=TEST_MESSAGE):
self.hook.get_tag_for_template_name( # pylint: disable=no-value-for-parameter
location=TEST_LOCATION,
entry_group=TEST_ENTRY_GROUP_ID,
@@ -1518,7 +1519,7 @@ def test_get_tag_for_template_name(self, mock_get_conn, mock_get_creds_and_proje
)
@mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn")
def test_rename_tag_template_field(self, mock_get_conn, mock_get_creds_and_project_id) -> None:
- with self.assertRaisesRegex(AirflowException, TEST_MESSAGE):
+ with pytest.raises(AirflowException, match=TEST_MESSAGE):
self.hook.rename_tag_template_field( # pylint: disable=no-value-for-parameter
location=TEST_LOCATION,
tag_template=TEST_TAG_TEMPLATE_ID,
@@ -1535,7 +1536,7 @@ def test_rename_tag_template_field(self, mock_get_conn, mock_get_creds_and_proje
)
@mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn")
def test_update_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> None:
- with self.assertRaisesRegex(AirflowException, TEST_MESSAGE):
+ with pytest.raises(AirflowException, match=TEST_MESSAGE):
self.hook.update_entry( # pylint: disable=no-value-for-parameter
entry=TEST_ENTRY,
update_mask=TEST_UPDATE_MASK,
@@ -1553,7 +1554,7 @@ def test_update_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> Non
)
@mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn")
def test_update_tag(self, mock_get_conn, mock_get_creds_and_project_id) -> None:
- with self.assertRaisesRegex(AirflowException, TEST_MESSAGE):
+ with pytest.raises(AirflowException, match=TEST_MESSAGE):
self.hook.update_tag( # pylint: disable=no-value-for-parameter
tag=deepcopy(TEST_TAG),
update_mask=TEST_UPDATE_MASK,
@@ -1572,7 +1573,7 @@ def test_update_tag(self, mock_get_conn, mock_get_creds_and_project_id) -> None:
)
@mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn")
def test_update_tag_template(self, mock_get_conn, mock_get_creds_and_project_id) -> None:
- with self.assertRaisesRegex(AirflowException, TEST_MESSAGE):
+ with pytest.raises(AirflowException, match=TEST_MESSAGE):
self.hook.update_tag_template( # pylint: disable=no-value-for-parameter
tag_template=TEST_TAG_TEMPLATE,
update_mask=TEST_UPDATE_MASK,
@@ -1589,7 +1590,7 @@ def test_update_tag_template(self, mock_get_conn, mock_get_creds_and_project_id)
)
@mock.patch("airflow.providers.google.cloud.hooks.datacatalog.CloudDataCatalogHook.get_conn")
def test_update_tag_template_field(self, mock_get_conn, mock_get_creds_and_project_id) -> None:
- with self.assertRaisesRegex(AirflowException, TEST_MESSAGE):
+ with pytest.raises(AirflowException, match=TEST_MESSAGE):
self.hook.update_tag_template_field( # pylint: disable=no-value-for-parameter
tag_template_field=TEST_TAG_TEMPLATE_FIELD,
update_mask=TEST_UPDATE_MASK,
diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py b/tests/providers/google/cloud/hooks/test_dataflow.py
index 3e6e759448167..5297b307fb76b 100644
--- a/tests/providers/google/cloud/hooks/test_dataflow.py
+++ b/tests/providers/google/cloud/hooks/test_dataflow.py
@@ -26,6 +26,7 @@
from unittest.mock import MagicMock
from uuid import UUID
+import pytest
from parameterized import parameterized
from airflow.exceptions import AirflowException
@@ -148,10 +149,10 @@ class FixtureFallback:
def test_fn(self, *args, **kwargs):
mock_instance(*args, **kwargs)
- with self.assertRaisesRegex(
+ with pytest.raises(
AirflowException,
- "The mutually exclusive parameter `project_id` and `project` key in `variables` parameter are "
- "both present\\. Please remove one\\.",
+ match="The mutually exclusive parameter `project_id` and `project` key in `variables` parameter "
+ "are both present\\. Please remove one\\.",
):
FixtureFallback().test_fn(variables={'project': "TEST"}, project_id="TEST2")
@@ -163,8 +164,8 @@ class FixutureFallback:
def test_fn(self, *args, **kwargs):
mock_instance(*args, **kwargs)
- with self.assertRaisesRegex(
- AirflowException, "You must use keyword arguments in this methods rather than positional"
+ with pytest.raises(
+ AirflowException, match="You must use keyword arguments in this methods rather than positional"
):
FixutureFallback().test_fn({'project': "TEST"}, "TEST2")
@@ -190,7 +191,7 @@ def test_dataflow_client_creation(self, mock_build, mock_authorize):
mock_build.assert_called_once_with(
'dataflow', 'v1b3', http=mock_authorize.return_value, cache_discovery=False
)
- self.assertEqual(mock_build.return_value, result)
+ assert mock_build.return_value == result
@mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
@mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
@@ -220,7 +221,7 @@ def test_start_python_dataflow(self, mock_conn, mock_dataflow, mock_dataflowjob,
'--staging_location=gs://test/staging',
f'--job_name={JOB_NAME}-{MOCK_UUID_PREFIX}',
]
- self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd))
+ assert sorted(mock_dataflow.call_args[1]["cmd"]) == sorted(expected_cmd)
@mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
@mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
@@ -254,7 +255,7 @@ def test_start_python_dataflow_with_custom_region_as_variable(
'--staging_location=gs://test/staging',
f'--job_name={JOB_NAME}-{MOCK_UUID_PREFIX}',
]
- self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd))
+ assert sorted(mock_dataflow.call_args[1]["cmd"]) == sorted(expected_cmd)
@mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
@mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
@@ -287,7 +288,7 @@ def test_start_python_dataflow_with_custom_region_as_parameter(
'--staging_location=gs://test/staging',
f'--job_name={JOB_NAME}-{MOCK_UUID_PREFIX}',
]
- self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd))
+ assert sorted(mock_dataflow.call_args[1]["cmd"]) == sorted(expected_cmd)
@mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
@mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
@@ -324,7 +325,7 @@ def test_start_python_dataflow_with_multiple_extra_packages(
'--staging_location=gs://test/staging',
f'--job_name={JOB_NAME}-{MOCK_UUID_PREFIX}',
]
- self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd))
+ assert sorted(mock_dataflow.call_args[1]["cmd"]) == sorted(expected_cmd)
@parameterized.expand(
[
@@ -372,7 +373,7 @@ def test_start_python_dataflow_with_custom_interpreter(
'--staging_location=gs://test/staging',
f'--job_name={JOB_NAME}-{MOCK_UUID_PREFIX}',
]
- self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd))
+ assert sorted(mock_dataflow.call_args[1]["cmd"]) == sorted(expected_cmd)
@parameterized.expand(
[
@@ -422,7 +423,7 @@ def test_start_python_dataflow_with_non_empty_py_requirements_and_without_system
'--staging_location=gs://test/staging',
f'--job_name={JOB_NAME}-{MOCK_UUID_PREFIX}',
]
- self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd))
+ assert sorted(mock_dataflow.call_args[1]["cmd"]) == sorted(expected_cmd)
@mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
@mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
@@ -437,7 +438,7 @@ def test_start_python_dataflow_with_empty_py_requirements_and_without_system_pac
dataflow_instance.wait_for_done.return_value = None
dataflowjob_instance = mock_dataflowjob.return_value
dataflowjob_instance.wait_for_done.return_value = None
- with self.assertRaisesRegex(AirflowException, "Invalid method invocation."):
+ with pytest.raises(AirflowException, match="Invalid method invocation."):
self.dataflow_hook.start_python_dataflow( # pylint: disable=no-value-for-parameter
job_name=JOB_NAME,
variables=DATAFLOW_VARIABLES_PY,
@@ -471,10 +472,7 @@ def test_start_java_dataflow(self, mock_conn, mock_dataflow, mock_dataflowjob, m
'--labels={"foo":"bar"}',
f'--jobName={JOB_NAME}-{MOCK_UUID_PREFIX}',
]
- self.assertListEqual(
- sorted(expected_cmd),
- sorted(mock_dataflow.call_args[1]["cmd"]),
- )
+ assert sorted(expected_cmd) == sorted(mock_dataflow.call_args[1]["cmd"])
@mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
@mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
@@ -508,7 +506,7 @@ def test_start_java_dataflow_with_multiple_values_in_variables(
'--labels={"foo":"bar"}',
f'--jobName={JOB_NAME}-{MOCK_UUID_PREFIX}',
]
- self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd))
+ assert sorted(mock_dataflow.call_args[1]["cmd"]) == sorted(expected_cmd)
@mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
@mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
@@ -541,10 +539,7 @@ def test_start_java_dataflow_with_custom_region_as_variable(
'--labels={"foo":"bar"}',
f'--jobName={JOB_NAME}-{MOCK_UUID_PREFIX}',
]
- self.assertListEqual(
- sorted(expected_cmd),
- sorted(mock_dataflow.call_args[1]["cmd"]),
- )
+ assert sorted(expected_cmd) == sorted(mock_dataflow.call_args[1]["cmd"])
@mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
@mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
@@ -577,10 +572,7 @@ def test_start_java_dataflow_with_custom_region_as_parameter(
'--labels={"foo":"bar"}',
f'--jobName={JOB_NAME}-{MOCK_UUID_PREFIX}',
]
- self.assertListEqual(
- sorted(expected_cmd),
- sorted(mock_dataflow.call_args[1]["cmd"]),
- )
+ assert sorted(expected_cmd) == sorted(mock_dataflow.call_args[1]["cmd"])
@mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
@mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
@@ -608,7 +600,7 @@ def test_start_java_dataflow_with_job_class(self, mock_conn, mock_dataflow, mock
'--labels={"foo":"bar"}',
f'--jobName={JOB_NAME}-{MOCK_UUID_PREFIX}',
]
- self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd))
+ assert sorted(mock_dataflow.call_args[1]["cmd"]) == sorted(expected_cmd)
@parameterized.expand(
[
@@ -628,13 +620,12 @@ def test_valid_dataflow_job_name(self, expected_result, job_name, append_job_nam
job_name=job_name, append_job_name=append_job_name
)
- self.assertEqual(expected_result, job_name)
+ assert expected_result == job_name
@parameterized.expand([("1dfjob@",), ("dfjob@",), ("df^jo",)])
def test_build_dataflow_job_name_with_invalid_value(self, job_name):
- self.assertRaises(
- ValueError, self.dataflow_hook._build_dataflow_job_name, job_name=job_name, append_job_name=False
- )
+ with pytest.raises(ValueError):
+ self.dataflow_hook._build_dataflow_job_name(job_name=job_name, append_job_name=False)
@mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
@mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
@@ -966,7 +957,7 @@ def test_start_flex_template(self, mock_conn, mock_controller):
mock_controller.return_value.get_jobs.wait_for_done.assrt_called_once_with()
mock_controller.return_value.get_jobs.assrt_called_once_with()
- self.assertEqual(result, {"id": TEST_JOB_ID})
+ assert result == {"id": TEST_JOB_ID}
@mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
@mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
@@ -1037,7 +1028,7 @@ def test_start_sql_job_failed_to_run(
drain_pipeline=False,
)
mock_controller.return_value.wait_for_done.assert_called_once()
- self.assertEqual(result, test_job)
+ assert result == test_job
@mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
@mock.patch(DATAFLOW_STRING.format('DataflowHook.provide_authorized_gcloud'))
@@ -1046,7 +1037,7 @@ def test_start_sql_job(self, mock_run, mock_provide_authorized_gcloud, mock_get_
mock_run.return_value = mock.MagicMock(
stdout=f"{TEST_JOB_ID}\n".encode(), stderr=f"{TEST_JOB_ID}\n".encode(), returncode=1
)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.dataflow_hook.start_sql_job(
job_name=TEST_SQL_JOB_NAME,
query=TEST_SQL_QUERY,
@@ -1134,7 +1125,7 @@ def test_dataflow_job_wait_for_multiple_jobs(self):
.return_value.execute.assert_called_once_with(num_retries=20)
# fmt: on
- self.assertEqual(dataflow_job.get_jobs(), [job, job])
+ assert dataflow_job.get_jobs() == [job, job]
@parameterized.expand(
[
@@ -1187,7 +1178,7 @@ def test_dataflow_job_wait_for_multiple_jobs_and_one_in_terminal_state(self, sta
num_retries=20,
multiple_jobs=True,
)
- with self.assertRaisesRegex(Exception, exception_regex):
+ with pytest.raises(Exception, match=exception_regex):
dataflow_job.wait_for_done()
def test_dataflow_job_wait_for_multiple_jobs_and_streaming_jobs(self):
@@ -1227,7 +1218,7 @@ def test_dataflow_job_wait_for_multiple_jobs_and_streaming_jobs(self):
)
dataflow_job.wait_for_done()
- self.assertEqual(1, mock_jobs_list.call_count)
+ assert 1 == mock_jobs_list.call_count
def test_dataflow_job_wait_for_single_jobs(self):
job = {
@@ -1269,7 +1260,7 @@ def test_dataflow_job_wait_for_single_jobs(self):
self.mock_dataflow.projects.return_value.locations.return_value. \
jobs.return_value.get.return_value.execute.assert_called_once_with(num_retries=20)
# fmt: on
- self.assertEqual(dataflow_job.get_jobs(), [job])
+ assert dataflow_job.get_jobs() == [job]
def test_dataflow_job_is_job_running_with_no_job(self):
# fmt: off
@@ -1301,7 +1292,7 @@ def test_dataflow_job_is_job_running_with_no_job(self):
)
result = dataflow_job.is_job_running()
- self.assertEqual(False, result)
+ assert result is False
# fmt: off
@parameterized.expand([
@@ -1337,7 +1328,7 @@ def test_check_dataflow_job_state_wait_until_finished(
wait_until_finished=wait_until_finished,
)
result = dataflow_job._check_dataflow_job_state(job)
- self.assertEqual(result, expected_result)
+ assert result == expected_result
# fmt: off
@parameterized.expand([
@@ -1375,7 +1366,7 @@ def test_check_dataflow_job_state_terminal_state(self, job_type, job_state, exce
num_retries=20,
multiple_jobs=True,
)
- with self.assertRaisesRegex(Exception, exception_regex):
+ with pytest.raises(Exception, match=exception_regex):
dataflow_job._check_dataflow_job_state(job)
def test_dataflow_job_cancel_job(self):
@@ -1582,7 +1573,7 @@ def test_fetch_list_job_messages_responses(self):
mock_list_next.assert_called_once_with(
previous_request=mock_list.return_value, previous_response="response_1"
)
- self.assertEqual(result, ["response_1"])
+ assert result == ["response_1"]
@mock.patch(DATAFLOW_STRING.format('_DataflowJobsController._fetch_list_job_messages_responses'))
def test_fetch_job_messages_by_id(self, mock_fetch_responses):
@@ -1600,7 +1591,7 @@ def test_fetch_job_messages_by_id(self, mock_fetch_responses):
)
result = jobs_controller.fetch_job_messages_by_id(TEST_JOB_ID)
mock_fetch_responses.assert_called_once_with(job_id=TEST_JOB_ID)
- self.assertEqual(result, ['message_1', 'message_2'])
+ assert result == ['message_1', 'message_2']
@mock.patch(DATAFLOW_STRING.format('_DataflowJobsController._fetch_list_job_messages_responses'))
def test_fetch_job_autoscaling_events_by_id(self, mock_fetch_responses):
@@ -1618,7 +1609,7 @@ def test_fetch_job_autoscaling_events_by_id(self, mock_fetch_responses):
)
result = jobs_controller.fetch_job_autoscaling_events_by_id(TEST_JOB_ID)
mock_fetch_responses.assert_called_once_with(job_id=TEST_JOB_ID)
- self.assertEqual(result, ['event_1', 'event_2'])
+ assert result == ['event_1', 'event_2']
APACHE_BEAM_V_2_14_0_JAVA_SDK_LOG = f""""\
@@ -1700,11 +1691,11 @@ class TestDataflow(unittest.TestCase):
def test_data_flow_valid_job_id(self, log):
echos = ";".join([f"echo {shlex.quote(line)}" for line in log.split("\n")])
cmd = ["bash", "-c", echos]
- self.assertEqual(_DataflowRunner(cmd).wait_for_done(), TEST_JOB_ID)
+ assert _DataflowRunner(cmd).wait_for_done() == TEST_JOB_ID
def test_data_flow_missing_job_id(self):
cmd = ['echo', 'unit testing']
- self.assertEqual(_DataflowRunner(cmd).wait_for_done(), None)
+ assert _DataflowRunner(cmd).wait_for_done() is None
@mock.patch('airflow.providers.google.cloud.hooks.dataflow._DataflowRunner.log')
@mock.patch('subprocess.Popen')
@@ -1729,4 +1720,5 @@ def poll_resp_error():
mock_popen.return_value = mock_proc
dataflow = _DataflowRunner(['test', 'cmd'])
mock_logging.info.assert_called_once_with('Running command: %s', 'test cmd')
- self.assertRaises(Exception, dataflow.wait_for_done)
+ with pytest.raises(Exception):
+ dataflow.wait_for_done()
diff --git a/tests/providers/google/cloud/hooks/test_dataprep.py b/tests/providers/google/cloud/hooks/test_dataprep.py
index 6ab55fb6b030b..d00e540bc10d2 100644
--- a/tests/providers/google/cloud/hooks/test_dataprep.py
+++ b/tests/providers/google/cloud/hooks/test_dataprep.py
@@ -83,11 +83,11 @@ def test_get_jobs_for_job_group_should_retry_after_four_errors(self, mock_get_re
side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(), HTTPError()],
)
def test_get_jobs_for_job_group_raise_error_after_five_calls(self, mock_get_request):
- with pytest.raises(RetryError) as err:
+ with pytest.raises(RetryError) as ctx:
# pylint: disable=no-member
self.hook.get_jobs_for_job_group.retry.sleep = mock.Mock()
self.hook.get_jobs_for_job_group(JOB_ID)
- assert "HTTPError" in str(err)
+ assert "HTTPError" in str(ctx.value)
assert mock_get_request.call_count == 5
@patch("airflow.providers.google.cloud.hooks.dataprep.requests.get")
@@ -139,11 +139,11 @@ def test_get_job_group_should_retry_after_four_errors(self, mock_get_request):
side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(), HTTPError()],
)
def test_get_job_group_raise_error_after_five_calls(self, mock_get_request):
- with pytest.raises(RetryError) as err:
+ with pytest.raises(RetryError) as ctx:
# pylint: disable=no-member
self.hook.get_job_group.retry.sleep = mock.Mock()
self.hook.get_job_group(JOB_ID, EMBED, INCLUDE_DELETED)
- assert "HTTPError" in str(err)
+ assert "HTTPError" in str(ctx.value)
assert mock_get_request.call_count == 5
@patch("airflow.providers.google.cloud.hooks.dataprep.requests.post")
@@ -196,9 +196,9 @@ def test_run_job_group_should_retry_after_four_errors(self, mock_get_request):
side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(), HTTPError()],
)
def test_run_job_group_raise_error_after_five_calls(self, mock_get_request):
- with pytest.raises(RetryError) as err:
+ with pytest.raises(RetryError) as ctx:
# pylint: disable=no-member
self.hook.run_job_group.retry.sleep = mock.Mock()
self.hook.run_job_group(body_request=DATA)
- assert "HTTPError" in str(err)
+ assert "HTTPError" in str(ctx.value)
assert mock_get_request.call_count == 5
diff --git a/tests/providers/google/cloud/hooks/test_dataproc.py b/tests/providers/google/cloud/hooks/test_dataproc.py
index 20698791cb7b9..d09c91e13a663 100644
--- a/tests/providers/google/cloud/hooks/test_dataproc.py
+++ b/tests/providers/google/cloud/hooks/test_dataproc.py
@@ -19,6 +19,7 @@
import unittest
from unittest import mock
+import pytest
from google.cloud.dataproc_v1beta2.types import JobStatus # pylint: disable=no-name-in-module
from airflow.exceptions import AirflowException
@@ -245,7 +246,7 @@ def test_wait_for_job(self, mock_get_job):
mock.MagicMock(status=mock.MagicMock(state=JobStatus.RUNNING)),
mock.MagicMock(status=mock.MagicMock(state=JobStatus.ERROR)),
]
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.wait_for_job(job_id=JOB_ID, location=GCP_LOCATION, project_id=GCP_PROJECT, wait_time=0)
calls = [
mock.call(location=GCP_LOCATION, job_id=JOB_ID, project_id=GCP_PROJECT),
@@ -284,7 +285,7 @@ def test_submit_job(self, mock_client):
@mock.patch(DATAPROC_STRING.format("DataprocHook.submit_job"))
def test_submit(self, mock_submit_job, mock_wait_for_job):
mock_submit_job.return_value.reference.job_id = JOB_ID
- with self.assertWarns(DeprecationWarning):
+ with pytest.warns(DeprecationWarning):
self.hook.submit(project_id=GCP_PROJECT, job=JOB, region=GCP_LOCATION)
mock_submit_job.assert_called_once_with(location=GCP_LOCATION, project_id=GCP_PROJECT, job=JOB)
mock_wait_for_job.assert_called_once_with(
@@ -306,7 +307,7 @@ def test_cancel_job(self, mock_client):
@mock.patch(DATAPROC_STRING.format("DataprocHook.get_job_client"))
def test_cancel_job_deprecation_warning(self, mock_client):
- with self.assertWarns(DeprecationWarning):
+ with pytest.warns(DeprecationWarning):
self.hook.cancel_job(job_id=JOB_ID, project_id=GCP_PROJECT)
mock_client.assert_called_once_with(location='global')
mock_client.return_value.cancel_job.assert_called_once_with(
@@ -350,72 +351,72 @@ def test_init(self, mock_uuid):
properties=properties,
)
- self.assertDictEqual(job, builder.job)
+ assert job == builder.job
def test_add_labels(self):
labels = {"key": "value"}
self.builder.add_labels(labels)
- self.assertIn("key", self.builder.job["job"]["labels"])
- self.assertEqual("value", self.builder.job["job"]["labels"]["key"])
+ assert "key" in self.builder.job["job"]["labels"]
+ assert "value" == self.builder.job["job"]["labels"]["key"]
def test_add_variables(self):
variables = ["variable"]
self.builder.add_variables(variables)
- self.assertEqual(variables, self.builder.job["job"][self.job_type]["script_variables"])
+ assert variables == self.builder.job["job"][self.job_type]["script_variables"]
def test_add_args(self):
args = ["args"]
self.builder.add_args(args)
- self.assertEqual(args, self.builder.job["job"][self.job_type]["args"])
+ assert args == self.builder.job["job"][self.job_type]["args"]
def test_add_query(self):
query = ["query"]
self.builder.add_query(query)
- self.assertEqual({"queries": [query]}, self.builder.job["job"][self.job_type]["query_list"])
+ assert {"queries": [query]} == self.builder.job["job"][self.job_type]["query_list"]
def test_add_query_uri(self):
query_uri = "query_uri"
self.builder.add_query_uri(query_uri)
- self.assertEqual(query_uri, self.builder.job["job"][self.job_type]["query_file_uri"])
+ assert query_uri == self.builder.job["job"][self.job_type]["query_file_uri"]
def test_add_jar_file_uris(self):
jar_file_uris = ["jar_file_uris"]
self.builder.add_jar_file_uris(jar_file_uris)
- self.assertEqual(jar_file_uris, self.builder.job["job"][self.job_type]["jar_file_uris"])
+ assert jar_file_uris == self.builder.job["job"][self.job_type]["jar_file_uris"]
def test_add_archive_uris(self):
archive_uris = ["archive_uris"]
self.builder.add_archive_uris(archive_uris)
- self.assertEqual(archive_uris, self.builder.job["job"][self.job_type]["archive_uris"])
+ assert archive_uris == self.builder.job["job"][self.job_type]["archive_uris"]
def test_add_file_uris(self):
file_uris = ["file_uris"]
self.builder.add_file_uris(file_uris)
- self.assertEqual(file_uris, self.builder.job["job"][self.job_type]["file_uris"])
+ assert file_uris == self.builder.job["job"][self.job_type]["file_uris"]
def test_add_python_file_uris(self):
python_file_uris = ["python_file_uris"]
self.builder.add_python_file_uris(python_file_uris)
- self.assertEqual(python_file_uris, self.builder.job["job"][self.job_type]["python_file_uris"])
+ assert python_file_uris == self.builder.job["job"][self.job_type]["python_file_uris"]
def test_set_main_error(self):
- with self.assertRaises(Exception):
+ with pytest.raises(Exception):
self.builder.set_main("test", "test")
def test_set_main_class(self):
main = "main"
self.builder.set_main(main_class=main, main_jar=None)
- self.assertEqual(main, self.builder.job["job"][self.job_type]["main_class"])
+ assert main == self.builder.job["job"][self.job_type]["main_class"]
def test_set_main_jar(self):
main = "main"
self.builder.set_main(main_class=None, main_jar=main)
- self.assertEqual(main, self.builder.job["job"][self.job_type]["main_jar_file_uri"])
+ assert main == self.builder.job["job"][self.job_type]["main_jar_file_uri"]
def test_set_python_main(self):
main = "main"
self.builder.set_python_main(main)
- self.assertEqual(main, self.builder.job["job"][self.job_type]["main_python_file_uri"])
+ assert main == self.builder.job["job"][self.job_type]["main_python_file_uri"]
@mock.patch(DATAPROC_STRING.format("uuid.uuid4"))
def test_set_job_name(self, mock_uuid):
@@ -424,7 +425,7 @@ def test_set_job_name(self, mock_uuid):
name = "name"
self.builder.set_job_name(name)
name += "_" + uuid[:8]
- self.assertEqual(name, self.builder.job["job"]["reference"]["job_id"])
+ assert name == self.builder.job["job"]["reference"]["job_id"]
def test_build(self):
- self.assertEqual(self.builder.job, self.builder.build())
+ assert self.builder.job == self.builder.build()
diff --git a/tests/providers/google/cloud/hooks/test_datastore.py b/tests/providers/google/cloud/hooks/test_datastore.py
index 3d9216ab9daee..6c16773ff65da 100644
--- a/tests/providers/google/cloud/hooks/test_datastore.py
+++ b/tests/providers/google/cloud/hooks/test_datastore.py
@@ -21,6 +21,8 @@
from unittest import mock
from unittest.mock import call, patch
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.datastore import DatastoreHook
@@ -51,8 +53,8 @@ def test_get_conn(self, mock_build, mock_authorize):
mock_build.assert_called_once_with(
'datastore', 'v1', http=mock_authorize.return_value, cache_discovery=False
)
- self.assertEqual(conn, mock_build.return_value)
- self.assertEqual(conn, self.datastore_hook.connection)
+ assert conn == mock_build.return_value
+ assert conn == self.datastore_hook.connection
@patch('airflow.providers.google.cloud.hooks.datastore.DatastoreHook.get_conn')
def test_allocate_ids(self, mock_get_conn):
@@ -67,7 +69,7 @@ def test_allocate_ids(self, mock_get_conn):
allocate_ids.assert_called_once_with(projectId=GCP_PROJECT_ID, body={'keys': partial_keys})
execute = allocate_ids.return_value.execute
execute.assert_called_once_with(num_retries=mock.ANY)
- self.assertEqual(keys, execute.return_value['keys'])
+ assert keys == execute.return_value['keys']
@patch(
'airflow.providers.google.cloud.hooks.datastore.DatastoreHook.project_id',
@@ -79,11 +81,11 @@ def test_allocate_ids_no_project_id(self, mock_get_conn, mock_project_id):
self.datastore_hook.connection = mock_get_conn.return_value
partial_keys = []
- with self.assertRaises(AirflowException) as err:
+ with pytest.raises(AirflowException) as ctx:
self.datastore_hook.allocate_ids( # pylint: disable=no-value-for-parameter
partial_keys=partial_keys
)
- self.assertIn("project_id", str(err.exception))
+ assert "project_id" in str(ctx.value)
@patch('airflow.providers.google.cloud.hooks.datastore.DatastoreHook.get_conn')
def test_begin_transaction(self, mock_get_conn):
@@ -100,7 +102,7 @@ def test_begin_transaction(self, mock_get_conn):
begin_transaction.assert_called_once_with(projectId=GCP_PROJECT_ID, body={'transactionOptions': {}})
execute = begin_transaction.return_value.execute
execute.assert_called_once_with(num_retries=mock.ANY)
- self.assertEqual(transaction, execute.return_value['transaction'])
+ assert transaction == execute.return_value['transaction']
@patch(
'airflow.providers.google.cloud.hooks.datastore.DatastoreHook.project_id',
@@ -110,9 +112,9 @@ def test_begin_transaction(self, mock_get_conn):
@patch('airflow.providers.google.cloud.hooks.datastore.DatastoreHook.get_conn')
def test_begin_transaction_no_project_id(self, mock_get_conn, mock_project_id):
self.datastore_hook.connection = mock_get_conn.return_value
- with self.assertRaises(AirflowException) as err:
+ with pytest.raises(AirflowException) as ctx:
self.datastore_hook.begin_transaction() # pylint: disable=no-value-for-parameter
- self.assertIn("project_id", str(err.exception))
+ assert "project_id" in str(ctx.value)
@patch('airflow.providers.google.cloud.hooks.datastore.DatastoreHook.get_conn')
def test_commit(self, mock_get_conn):
@@ -127,7 +129,7 @@ def test_commit(self, mock_get_conn):
commit.assert_called_once_with(projectId=GCP_PROJECT_ID, body=body)
execute = commit.return_value.execute
execute.assert_called_once_with(num_retries=mock.ANY)
- self.assertEqual(resp, execute.return_value)
+ assert resp == execute.return_value
@patch(
'airflow.providers.google.cloud.hooks.datastore.DatastoreHook.project_id',
@@ -139,9 +141,9 @@ def test_commit_no_project_id(self, mock_get_conn, mock_project_id):
self.datastore_hook.connection = mock_get_conn.return_value
body = {'item': 'a'}
- with self.assertRaises(AirflowException) as err:
+ with pytest.raises(AirflowException) as ctx:
self.datastore_hook.commit(body=body) # pylint: disable=no-value-for-parameter
- self.assertIn("project_id", str(err.exception))
+ assert "project_id" in str(ctx.value)
@patch('airflow.providers.google.cloud.hooks.datastore.DatastoreHook.get_conn')
def test_lookup(self, mock_get_conn):
@@ -163,7 +165,7 @@ def test_lookup(self, mock_get_conn):
)
execute = lookup.return_value.execute
execute.assert_called_once_with(num_retries=mock.ANY)
- self.assertEqual(resp, execute.return_value)
+ assert resp == execute.return_value
@patch(
'airflow.providers.google.cloud.hooks.datastore.DatastoreHook.project_id',
@@ -177,13 +179,13 @@ def test_lookup_no_project_id(self, mock_get_conn, mock_project_id):
read_consistency = 'ENUM'
transaction = 'transaction'
- with self.assertRaises(AirflowException) as err:
+ with pytest.raises(AirflowException) as ctx:
self.datastore_hook.lookup( # pylint: disable=no-value-for-parameter
keys=keys,
read_consistency=read_consistency,
transaction=transaction,
)
- self.assertIn("project_id", str(err.exception))
+ assert "project_id" in str(ctx.value)
@patch('airflow.providers.google.cloud.hooks.datastore.DatastoreHook.get_conn')
def test_rollback(self, mock_get_conn):
@@ -209,9 +211,9 @@ def test_rollback_no_project_id(self, mock_get_conn, mock_project_id):
self.datastore_hook.connection = mock_get_conn.return_value
transaction = 'transaction'
- with self.assertRaises(AirflowException) as err:
+ with pytest.raises(AirflowException) as ctx:
self.datastore_hook.rollback(transaction=transaction) # pylint: disable=no-value-for-parameter
- self.assertIn("project_id", str(err.exception))
+ assert "project_id" in str(ctx.value)
@patch('airflow.providers.google.cloud.hooks.datastore.DatastoreHook.get_conn')
def test_run_query(self, mock_get_conn):
@@ -226,7 +228,7 @@ def test_run_query(self, mock_get_conn):
run_query.assert_called_once_with(projectId=GCP_PROJECT_ID, body=body)
execute = run_query.return_value.execute
execute.assert_called_once_with(num_retries=mock.ANY)
- self.assertEqual(resp, execute.return_value['batch'])
+ assert resp == execute.return_value['batch']
@patch(
'airflow.providers.google.cloud.hooks.datastore.DatastoreHook.project_id',
@@ -238,9 +240,9 @@ def test_run_query_no_project_id(self, mock_get_conn, mock_project_id):
self.datastore_hook.connection = mock_get_conn.return_value
body = {'item': 'a'}
- with self.assertRaises(AirflowException) as err:
+ with pytest.raises(AirflowException) as ctx:
self.datastore_hook.run_query(body=body) # pylint: disable=no-value-for-parameter
- self.assertIn("project_id", str(err.exception))
+ assert "project_id" in str(ctx.value)
@patch('airflow.providers.google.cloud.hooks.datastore.DatastoreHook.get_conn')
def test_get_operation(self, mock_get_conn):
@@ -257,7 +259,7 @@ def test_get_operation(self, mock_get_conn):
get.assert_called_once_with(name=name)
execute = get.return_value.execute
execute.assert_called_once_with(num_retries=mock.ANY)
- self.assertEqual(resp, execute.return_value)
+ assert resp == execute.return_value
@patch('airflow.providers.google.cloud.hooks.datastore.DatastoreHook.get_conn')
def test_delete_operation(self, mock_get_conn):
@@ -274,7 +276,7 @@ def test_delete_operation(self, mock_get_conn):
delete.assert_called_once_with(name=name)
execute = delete.return_value.execute
execute.assert_called_once_with(num_retries=mock.ANY)
- self.assertEqual(resp, execute.return_value)
+ assert resp == execute.return_value
@patch('airflow.providers.google.cloud.hooks.datastore.time.sleep')
@patch(
@@ -292,7 +294,7 @@ def test_poll_operation_until_done(self, mock_get_operation, mock_time_sleep):
mock_get_operation.assert_has_calls([call(name), call(name)])
mock_time_sleep.assert_called_once_with(polling_interval_in_seconds)
- self.assertEqual(result, {'metadata': {'common': {'state': 'NOT PROCESSING'}}})
+ assert result == {'metadata': {'common': {'state': 'NOT PROCESSING'}}}
@patch('airflow.providers.google.cloud.hooks.datastore.DatastoreHook.get_conn')
def test_export_to_storage_bucket(self, mock_get_conn):
@@ -323,7 +325,7 @@ def test_export_to_storage_bucket(self, mock_get_conn):
)
execute = export.return_value.execute
execute.assert_called_once_with(num_retries=mock.ANY)
- self.assertEqual(resp, execute.return_value)
+ assert resp == execute.return_value
@patch(
'airflow.providers.google.cloud.hooks.datastore.DatastoreHook.project_id',
@@ -338,14 +340,14 @@ def test_export_to_storage_bucket_no_project_id(self, mock_get_conn, mock_projec
entity_filter = {}
labels = {}
- with self.assertRaises(AirflowException) as err:
+ with pytest.raises(AirflowException) as ctx:
self.datastore_hook.export_to_storage_bucket( # pylint: disable=no-value-for-parameter
bucket=bucket,
namespace=namespace,
entity_filter=entity_filter,
labels=labels,
)
- self.assertIn("project_id", str(err.exception))
+ assert "project_id" in str(ctx.value)
@patch('airflow.providers.google.cloud.hooks.datastore.DatastoreHook.get_conn')
def test_import_from_storage_bucket(self, mock_get_conn):
@@ -378,7 +380,7 @@ def test_import_from_storage_bucket(self, mock_get_conn):
)
execute = import_.return_value.execute
execute.assert_called_once_with(num_retries=mock.ANY)
- self.assertEqual(resp, execute.return_value)
+ assert resp == execute.return_value
@patch(
'airflow.providers.google.cloud.hooks.datastore.DatastoreHook.project_id',
@@ -394,7 +396,7 @@ def test_import_from_storage_bucket_no_project_id(self, mock_get_conn, mock_proj
entity_filter = {}
labels = {}
- with self.assertRaises(AirflowException) as err:
+ with pytest.raises(AirflowException) as ctx:
self.datastore_hook.import_from_storage_bucket( # pylint: disable=no-value-for-parameter
bucket=bucket,
file=file,
@@ -402,4 +404,4 @@ def test_import_from_storage_bucket_no_project_id(self, mock_get_conn, mock_proj
entity_filter=entity_filter,
labels=labels,
)
- self.assertIn("project_id", str(err.exception))
+ assert "project_id" in str(ctx.value)
diff --git a/tests/providers/google/cloud/hooks/test_dlp.py b/tests/providers/google/cloud/hooks/test_dlp.py
index c9f14d99e81cf..5aac84a75a09c 100644
--- a/tests/providers/google/cloud/hooks/test_dlp.py
+++ b/tests/providers/google/cloud/hooks/test_dlp.py
@@ -27,6 +27,7 @@
from unittest import mock
from unittest.mock import PropertyMock
+import pytest
from google.cloud.dlp_v2.types import DlpJob
from airflow.exceptions import AirflowException
@@ -76,8 +77,8 @@ def test_dlp_service_client_creation(self, mock_client, mock_get_creds, mock_cli
mock_client.assert_called_once_with(
credentials=mock_get_creds.return_value, client_info=mock_client_info.return_value
)
- self.assertEqual(mock_client.return_value, result)
- self.assertEqual(self.hook._client, result)
+ assert mock_client.return_value == result
+ assert self.hook._client == result
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_cancel_dlp_job(self, get_conn):
@@ -89,7 +90,7 @@ def test_cancel_dlp_job(self, get_conn):
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_cancel_dlp_job_without_dlp_job_id(self, _):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.cancel_dlp_job(dlp_job_id=None, project_id=PROJECT_ID)
@mock.patch(
@@ -99,7 +100,7 @@ def test_cancel_dlp_job_without_dlp_job_id(self, _):
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_cancel_dlp_job_without_parent(self, _, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.cancel_dlp_job(dlp_job_id=DLP_JOB_ID) # pylint: disable=no-value-for-parameter
@mock.patch(
@@ -112,7 +113,7 @@ def test_create_deidentify_template_with_org_id(self, get_conn, mock_project_id)
get_conn.return_value.create_deidentify_template.return_value = API_RESPONSE
result = self.hook.create_deidentify_template(organization_id=ORGANIZATION_ID)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.create_deidentify_template.assert_called_once_with(
parent=ORGANIZATION_PATH,
deidentify_template=None,
@@ -127,7 +128,7 @@ def test_create_deidentify_template_with_project_id(self, get_conn):
get_conn.return_value.create_deidentify_template.return_value = API_RESPONSE
result = self.hook.create_deidentify_template(project_id=PROJECT_ID)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.create_deidentify_template.assert_called_once_with(
parent=PROJECT_PATH,
deidentify_template=None,
@@ -144,7 +145,7 @@ def test_create_deidentify_template_with_project_id(self, get_conn):
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_create_deidentify_template_without_parent(self, _, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.create_deidentify_template()
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
@@ -152,7 +153,7 @@ def test_create_dlp_job(self, get_conn):
get_conn.return_value.create_dlp_job.return_value = API_RESPONSE
result = self.hook.create_dlp_job(project_id=PROJECT_ID, wait_until_finished=False)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.create_dlp_job.assert_called_once_with(
parent=PROJECT_PATH,
inspect_job=None,
@@ -170,7 +171,7 @@ def test_create_dlp_job(self, get_conn):
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_create_dlp_job_without_project_id(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.create_dlp_job() # pylint: disable=no-value-for-parameter
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
@@ -196,7 +197,7 @@ def test_create_inspect_template_with_org_id(self, get_conn, mock_project_id):
get_conn.return_value.create_inspect_template.return_value = API_RESPONSE
result = self.hook.create_inspect_template(organization_id=ORGANIZATION_ID)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.create_inspect_template.assert_called_once_with(
parent=ORGANIZATION_PATH,
inspect_template=None,
@@ -211,7 +212,7 @@ def test_create_inspect_template_with_project_id(self, get_conn):
get_conn.return_value.create_inspect_template.return_value = API_RESPONSE
result = self.hook.create_inspect_template(project_id=PROJECT_ID)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.create_inspect_template.assert_called_once_with(
parent=PROJECT_PATH,
inspect_template=None,
@@ -228,7 +229,7 @@ def test_create_inspect_template_with_project_id(self, get_conn):
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_create_inspect_template_without_parent(self, _, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.create_inspect_template()
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
@@ -236,7 +237,7 @@ def test_create_job_trigger(self, get_conn):
get_conn.return_value.create_job_trigger.return_value = API_RESPONSE
result = self.hook.create_job_trigger(project_id=PROJECT_ID)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.create_job_trigger.assert_called_once_with(
parent=PROJECT_PATH,
job_trigger=None,
@@ -253,7 +254,7 @@ def test_create_job_trigger(self, get_conn):
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_create_job_trigger_without_parent(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.create_job_trigger() # pylint: disable=no-value-for-parameter
@mock.patch(
@@ -266,7 +267,7 @@ def test_create_stored_info_type_with_org_id(self, get_conn, mock_project_id):
get_conn.return_value.create_stored_info_type.return_value = API_RESPONSE
result = self.hook.create_stored_info_type(organization_id=ORGANIZATION_ID)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.create_stored_info_type.assert_called_once_with(
parent=ORGANIZATION_PATH,
config=None,
@@ -281,7 +282,7 @@ def test_create_stored_info_type_with_project_id(self, get_conn):
get_conn.return_value.create_stored_info_type.return_value = API_RESPONSE
result = self.hook.create_stored_info_type(project_id=PROJECT_ID)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.create_stored_info_type.assert_called_once_with(
parent=PROJECT_PATH,
config=None,
@@ -298,7 +299,7 @@ def test_create_stored_info_type_with_project_id(self, get_conn):
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_create_stored_info_type_without_parent(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.create_stored_info_type()
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
@@ -306,7 +307,7 @@ def test_deidentify_content(self, get_conn):
get_conn.return_value.deidentify_content.return_value = API_RESPONSE
result = self.hook.deidentify_content(project_id=PROJECT_ID)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.deidentify_content.assert_called_once_with(
parent=PROJECT_PATH,
deidentify_config=None,
@@ -326,7 +327,7 @@ def test_deidentify_content(self, get_conn):
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_deidentify_content_without_parent(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.deidentify_content() # pylint: disable=no-value-for-parameter
@mock.patch(
@@ -358,7 +359,7 @@ def test_delete_deidentify_template_with_project_id(self, get_conn):
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_delete_deidentify_template_without_template_id(self, _):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.delete_deidentify_template(template_id=None)
@mock.patch(
@@ -368,7 +369,7 @@ def test_delete_deidentify_template_without_template_id(self, _):
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_delete_deidentify_template_without_parent(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.delete_deidentify_template(template_id=TEMPLATE_ID)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
@@ -381,7 +382,7 @@ def test_delete_dlp_job(self, get_conn):
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_delete_dlp_job_without_dlp_job_id(self, _):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.delete_dlp_job(dlp_job_id=None, project_id=PROJECT_ID)
@mock.patch(
@@ -391,7 +392,7 @@ def test_delete_dlp_job_without_dlp_job_id(self, _):
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_delete_dlp_job_without_parent(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.delete_dlp_job(dlp_job_id=DLP_JOB_ID) # pylint: disable=no-value-for-parameter
@mock.patch(
@@ -423,7 +424,7 @@ def test_delete_inspect_template_with_project_id(self, get_conn):
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_delete_inspect_template_without_template_id(self, _):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.delete_inspect_template(template_id=None)
@mock.patch(
@@ -433,7 +434,7 @@ def test_delete_inspect_template_without_template_id(self, _):
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_delete_inspect_template_without_parent(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.delete_inspect_template(template_id=TEMPLATE_ID)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
@@ -446,7 +447,7 @@ def test_delete_job_trigger(self, get_conn):
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_delete_job_trigger_without_trigger_id(self, _):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.delete_job_trigger(job_trigger_id=None, project_id=PROJECT_ID)
@mock.patch(
@@ -456,7 +457,7 @@ def test_delete_job_trigger_without_trigger_id(self, _):
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_delete_job_trigger_without_parent(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.delete_job_trigger(job_trigger_id=TRIGGER_ID) # pylint: disable=no-value-for-parameter
@mock.patch(
@@ -490,7 +491,7 @@ def test_delete_stored_info_type_with_project_id(self, get_conn):
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_delete_stored_info_type_without_stored_info_type_id(self, _):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.delete_stored_info_type(stored_info_type_id=None)
@mock.patch(
@@ -500,7 +501,7 @@ def test_delete_stored_info_type_without_stored_info_type_id(self, _):
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_delete_stored_info_type_without_parent(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.delete_stored_info_type(stored_info_type_id=STORED_INFO_TYPE_ID)
@mock.patch(
@@ -513,7 +514,7 @@ def test_get_deidentify_template_with_org_id(self, get_conn, mock_project_id):
get_conn.return_value.get_deidentify_template.return_value = API_RESPONSE
result = self.hook.get_deidentify_template(template_id=TEMPLATE_ID, organization_id=ORGANIZATION_ID)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.get_deidentify_template.assert_called_once_with(
name=DEIDENTIFY_TEMPLATE_ORGANIZATION_PATH,
retry=None,
@@ -526,7 +527,7 @@ def test_get_deidentify_template_with_project_id(self, get_conn):
get_conn.return_value.get_deidentify_template.return_value = API_RESPONSE
result = self.hook.get_deidentify_template(template_id=TEMPLATE_ID, project_id=PROJECT_ID)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.get_deidentify_template.assert_called_once_with(
name=DEIDENTIFY_TEMPLATE_PROJECT_PATH,
retry=None,
@@ -536,7 +537,7 @@ def test_get_deidentify_template_with_project_id(self, get_conn):
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_get_deidentify_template_without_template_id(self, _):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.get_deidentify_template(template_id=None)
@mock.patch(
@@ -546,7 +547,7 @@ def test_get_deidentify_template_without_template_id(self, _):
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_get_deidentify_template_without_parent(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.get_deidentify_template(template_id=TEMPLATE_ID)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
@@ -554,14 +555,14 @@ def test_get_dlp_job(self, get_conn):
get_conn.return_value.get_dlp_job.return_value = API_RESPONSE
result = self.hook.get_dlp_job(dlp_job_id=DLP_JOB_ID, project_id=PROJECT_ID)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.get_dlp_job.assert_called_once_with(
name=DLP_JOB_PATH, retry=None, timeout=None, metadata=None
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_get_dlp_job_without_dlp_job_id(self, _):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.get_dlp_job(dlp_job_id=None, project_id=PROJECT_ID)
@mock.patch(
@@ -571,7 +572,7 @@ def test_get_dlp_job_without_dlp_job_id(self, _):
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_get_dlp_job_without_parent(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.get_dlp_job(dlp_job_id=DLP_JOB_ID) # pylint: disable=no-value-for-parameter
@mock.patch(
@@ -584,7 +585,7 @@ def test_get_inspect_template_with_org_id(self, get_conn, mock_project_id):
get_conn.return_value.get_inspect_template.return_value = API_RESPONSE
result = self.hook.get_inspect_template(template_id=TEMPLATE_ID, organization_id=ORGANIZATION_ID)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.get_inspect_template.assert_called_once_with(
name=INSPECT_TEMPLATE_ORGANIZATION_PATH,
retry=None,
@@ -597,7 +598,7 @@ def test_get_inspect_template_with_project_id(self, get_conn):
get_conn.return_value.get_inspect_template.return_value = API_RESPONSE
result = self.hook.get_inspect_template(template_id=TEMPLATE_ID, project_id=PROJECT_ID)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.get_inspect_template.assert_called_once_with(
name=INSPECT_TEMPLATE_PROJECT_PATH,
retry=None,
@@ -607,7 +608,7 @@ def test_get_inspect_template_with_project_id(self, get_conn):
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_get_inspect_template_without_template_id(self, _):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.get_inspect_template(template_id=None)
@mock.patch(
@@ -617,7 +618,7 @@ def test_get_inspect_template_without_template_id(self, _):
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_get_inspect_template_without_parent(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.get_inspect_template(template_id=TEMPLATE_ID)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
@@ -625,14 +626,14 @@ def test_get_job_trigger(self, get_conn):
get_conn.return_value.get_job_trigger.return_value = API_RESPONSE
result = self.hook.get_job_trigger(job_trigger_id=TRIGGER_ID, project_id=PROJECT_ID)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.get_job_trigger.assert_called_once_with(
name=JOB_TRIGGER_PATH, retry=None, timeout=None, metadata=None
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_get_job_trigger_without_trigger_id(self, _):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.get_job_trigger(job_trigger_id=None, project_id=PROJECT_ID)
@mock.patch(
@@ -642,7 +643,7 @@ def test_get_job_trigger_without_trigger_id(self, _):
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_get_job_trigger_without_parent(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.get_job_trigger(job_trigger_id=TRIGGER_ID) # pylint: disable=no-value-for-parameter
@mock.patch(
@@ -657,7 +658,7 @@ def test_get_stored_info_type_with_org_id(self, get_conn, mock_project_id):
stored_info_type_id=STORED_INFO_TYPE_ID, organization_id=ORGANIZATION_ID
)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.get_stored_info_type.assert_called_once_with(
name=STORED_INFO_TYPE_ORGANIZATION_PATH,
retry=None,
@@ -672,7 +673,7 @@ def test_get_stored_info_type_with_project_id(self, get_conn):
stored_info_type_id=STORED_INFO_TYPE_ID, project_id=PROJECT_ID
)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.get_stored_info_type.assert_called_once_with(
name=STORED_INFO_TYPE_PROJECT_PATH,
retry=None,
@@ -682,7 +683,7 @@ def test_get_stored_info_type_with_project_id(self, get_conn):
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_get_stored_info_type_without_stored_info_type_id(self, _):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.get_stored_info_type(stored_info_type_id=None)
@mock.patch(
@@ -692,7 +693,7 @@ def test_get_stored_info_type_without_stored_info_type_id(self, _):
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_get_stored_info_type_without_parent(self, mock_get_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.get_stored_info_type(stored_info_type_id=STORED_INFO_TYPE_ID)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
@@ -700,7 +701,7 @@ def test_inspect_content(self, get_conn):
get_conn.return_value.inspect_content.return_value = API_RESPONSE
result = self.hook.inspect_content(project_id=PROJECT_ID)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.inspect_content.assert_called_once_with(
parent=PROJECT_PATH,
inspect_config=None,
@@ -718,7 +719,7 @@ def test_inspect_content(self, get_conn):
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_inspect_content_without_parent(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.inspect_content() # pylint: disable=no-value-for-parameter
@mock.patch(
@@ -730,7 +731,7 @@ def test_inspect_content_without_parent(self, mock_get_conn, mock_project_id):
def test_list_deidentify_templates_with_org_id(self, get_conn, mock_project_id):
result = self.hook.list_deidentify_templates(organization_id=ORGANIZATION_ID)
- self.assertIsInstance(result, list)
+ assert isinstance(result, list)
get_conn.return_value.list_deidentify_templates.assert_called_once_with(
parent=ORGANIZATION_PATH,
page_size=None,
@@ -744,7 +745,7 @@ def test_list_deidentify_templates_with_org_id(self, get_conn, mock_project_id):
def test_list_deidentify_templates_with_project_id(self, get_conn):
result = self.hook.list_deidentify_templates(project_id=PROJECT_ID)
- self.assertIsInstance(result, list)
+ assert isinstance(result, list)
get_conn.return_value.list_deidentify_templates.assert_called_once_with(
parent=PROJECT_PATH,
page_size=None,
@@ -761,14 +762,14 @@ def test_list_deidentify_templates_with_project_id(self, get_conn):
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_list_deidentify_templates_without_parent(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.list_deidentify_templates()
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_list_dlp_jobs(self, get_conn):
result = self.hook.list_dlp_jobs(project_id=PROJECT_ID)
- self.assertIsInstance(result, list)
+ assert isinstance(result, list)
get_conn.return_value.list_dlp_jobs.assert_called_once_with(
parent=PROJECT_PATH,
filter_=None,
@@ -787,7 +788,7 @@ def test_list_dlp_jobs(self, get_conn):
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_list_dlp_jobs_without_parent(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.list_dlp_jobs() # pylint: disable=no-value-for-parameter
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
@@ -795,7 +796,7 @@ def test_list_info_types(self, get_conn):
get_conn.return_value.list_info_types.return_value = API_RESPONSE
result = self.hook.list_info_types()
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.list_info_types.assert_called_once_with(
language_code=None, filter_=None, retry=None, timeout=None, metadata=None
)
@@ -809,7 +810,7 @@ def test_list_info_types(self, get_conn):
def test_list_inspect_templates_with_org_id(self, get_conn, mock_project_id):
result = self.hook.list_inspect_templates(organization_id=ORGANIZATION_ID)
- self.assertIsInstance(result, list)
+ assert isinstance(result, list)
get_conn.return_value.list_inspect_templates.assert_called_once_with(
parent=ORGANIZATION_PATH,
page_size=None,
@@ -823,7 +824,7 @@ def test_list_inspect_templates_with_org_id(self, get_conn, mock_project_id):
def test_list_inspect_templates_with_project_id(self, get_conn):
result = self.hook.list_inspect_templates(project_id=PROJECT_ID)
- self.assertIsInstance(result, list)
+ assert isinstance(result, list)
get_conn.return_value.list_inspect_templates.assert_called_once_with(
parent=PROJECT_PATH,
page_size=None,
@@ -840,14 +841,14 @@ def test_list_inspect_templates_with_project_id(self, get_conn):
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_list_inspect_templates_without_parent(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.list_inspect_templates()
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_list_job_triggers(self, get_conn):
result = self.hook.list_job_triggers(project_id=PROJECT_ID)
- self.assertIsInstance(result, list)
+ assert isinstance(result, list)
get_conn.return_value.list_job_triggers.assert_called_once_with(
parent=PROJECT_PATH,
page_size=None,
@@ -865,7 +866,7 @@ def test_list_job_triggers(self, get_conn):
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_list_job_triggers_without_parent(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.list_job_triggers() # pylint: disable=no-value-for-parameter
@mock.patch(
@@ -877,7 +878,7 @@ def test_list_job_triggers_without_parent(self, mock_get_conn, mock_project_id):
def test_list_stored_info_types_with_org_id(self, get_conn, mock_project_id):
result = self.hook.list_stored_info_types(organization_id=ORGANIZATION_ID)
- self.assertIsInstance(result, list)
+ assert isinstance(result, list)
get_conn.return_value.list_stored_info_types.assert_called_once_with(
parent=ORGANIZATION_PATH,
page_size=None,
@@ -891,7 +892,7 @@ def test_list_stored_info_types_with_org_id(self, get_conn, mock_project_id):
def test_list_stored_info_types_with_project_id(self, get_conn):
result = self.hook.list_stored_info_types(project_id=PROJECT_ID)
- self.assertIsInstance(result, list)
+ assert isinstance(result, list)
get_conn.return_value.list_stored_info_types.assert_called_once_with(
parent=PROJECT_PATH,
page_size=None,
@@ -908,7 +909,7 @@ def test_list_stored_info_types_with_project_id(self, get_conn):
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_list_stored_info_types_without_parent(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.list_stored_info_types()
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
@@ -916,7 +917,7 @@ def test_redact_image(self, get_conn):
get_conn.return_value.redact_image.return_value = API_RESPONSE
result = self.hook.redact_image(project_id=PROJECT_ID)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.redact_image.assert_called_once_with(
parent=PROJECT_PATH,
inspect_config=None,
@@ -935,7 +936,7 @@ def test_redact_image(self, get_conn):
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_redact_image_without_parent(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.redact_image() # pylint: disable=no-value-for-parameter
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
@@ -943,7 +944,7 @@ def test_reidentify_content(self, get_conn):
get_conn.return_value.reidentify_content.return_value = API_RESPONSE
result = self.hook.reidentify_content(project_id=PROJECT_ID)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.reidentify_content.assert_called_once_with(
parent=PROJECT_PATH,
reidentify_config=None,
@@ -963,7 +964,7 @@ def test_reidentify_content(self, get_conn):
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_reidentify_content_without_parent(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.reidentify_content() # pylint: disable=no-value-for-parameter
@mock.patch(
@@ -978,7 +979,7 @@ def test_update_deidentify_template_with_org_id(self, get_conn, mock_project_id)
template_id=TEMPLATE_ID, organization_id=ORGANIZATION_ID
)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.update_deidentify_template.assert_called_once_with(
name=DEIDENTIFY_TEMPLATE_ORGANIZATION_PATH,
deidentify_template=None,
@@ -993,7 +994,7 @@ def test_update_deidentify_template_with_project_id(self, get_conn):
get_conn.return_value.update_deidentify_template.return_value = API_RESPONSE
result = self.hook.update_deidentify_template(template_id=TEMPLATE_ID, project_id=PROJECT_ID)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.update_deidentify_template.assert_called_once_with(
name=DEIDENTIFY_TEMPLATE_PROJECT_PATH,
deidentify_template=None,
@@ -1005,7 +1006,7 @@ def test_update_deidentify_template_with_project_id(self, get_conn):
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_update_deidentify_template_without_template_id(self, _):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.update_deidentify_template(template_id=None, organization_id=ORGANIZATION_ID)
@mock.patch(
@@ -1015,7 +1016,7 @@ def test_update_deidentify_template_without_template_id(self, _):
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_update_deidentify_template_without_parent(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.update_deidentify_template(template_id=TEMPLATE_ID)
@mock.patch(
@@ -1028,7 +1029,7 @@ def test_update_inspect_template_with_org_id(self, get_conn, mock_project_id):
get_conn.return_value.update_inspect_template.return_value = API_RESPONSE
result = self.hook.update_inspect_template(template_id=TEMPLATE_ID, organization_id=ORGANIZATION_ID)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.update_inspect_template.assert_called_once_with(
name=INSPECT_TEMPLATE_ORGANIZATION_PATH,
inspect_template=None,
@@ -1043,7 +1044,7 @@ def test_update_inspect_template_with_project_id(self, get_conn):
get_conn.return_value.update_inspect_template.return_value = API_RESPONSE
result = self.hook.update_inspect_template(template_id=TEMPLATE_ID, project_id=PROJECT_ID)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.update_inspect_template.assert_called_once_with(
name=INSPECT_TEMPLATE_PROJECT_PATH,
inspect_template=None,
@@ -1055,7 +1056,7 @@ def test_update_inspect_template_with_project_id(self, get_conn):
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_update_inspect_template_without_template_id(self, _):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.update_inspect_template(template_id=None, organization_id=ORGANIZATION_ID)
@mock.patch(
@@ -1065,7 +1066,7 @@ def test_update_inspect_template_without_template_id(self, _):
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_update_inspect_template_without_parent(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.update_inspect_template(template_id=TEMPLATE_ID)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
@@ -1073,7 +1074,7 @@ def test_update_job_trigger(self, get_conn):
get_conn.return_value.update_job_trigger.return_value = API_RESPONSE
result = self.hook.update_job_trigger(job_trigger_id=TRIGGER_ID, project_id=PROJECT_ID)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.update_job_trigger.assert_called_once_with(
name=JOB_TRIGGER_PATH,
job_trigger=None,
@@ -1085,7 +1086,7 @@ def test_update_job_trigger(self, get_conn):
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_update_job_trigger_without_job_trigger_id(self, _):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.update_job_trigger(job_trigger_id=None, project_id=PROJECT_ID)
@mock.patch(
@@ -1095,7 +1096,7 @@ def test_update_job_trigger_without_job_trigger_id(self, _):
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_update_job_trigger_without_parent(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.update_job_trigger(job_trigger_id=TRIGGER_ID) # pylint: disable=no-value-for-parameter
@mock.patch(
@@ -1110,7 +1111,7 @@ def test_update_stored_info_type_with_org_id(self, get_conn, mock_project_id):
stored_info_type_id=STORED_INFO_TYPE_ID, organization_id=ORGANIZATION_ID
)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.update_stored_info_type.assert_called_once_with(
name=STORED_INFO_TYPE_ORGANIZATION_PATH,
config=None,
@@ -1127,7 +1128,7 @@ def test_update_stored_info_type_with_project_id(self, get_conn):
stored_info_type_id=STORED_INFO_TYPE_ID, project_id=PROJECT_ID
)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.update_stored_info_type.assert_called_once_with(
name=STORED_INFO_TYPE_PROJECT_PATH,
config=None,
@@ -1139,7 +1140,7 @@ def test_update_stored_info_type_with_project_id(self, get_conn):
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_update_stored_info_type_without_stored_info_type_id(self, _):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.update_stored_info_type(stored_info_type_id=None, organization_id=ORGANIZATION_ID)
@mock.patch(
@@ -1149,5 +1150,5 @@ def test_update_stored_info_type_without_stored_info_type_id(self, _):
)
@mock.patch("airflow.providers.google.cloud.hooks.dlp.CloudDLPHook.get_conn")
def test_update_stored_info_type_without_parent(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.hook.update_stored_info_type(stored_info_type_id=STORED_INFO_TYPE_ID)
diff --git a/tests/providers/google/cloud/hooks/test_functions.py b/tests/providers/google/cloud/hooks/test_functions.py
index 76304b00b3a07..e1b59356ceb2c 100644
--- a/tests/providers/google/cloud/hooks/test_functions.py
+++ b/tests/providers/google/cloud/hooks/test_functions.py
@@ -20,6 +20,8 @@
from unittest import mock
from unittest.mock import PropertyMock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.functions import CloudFunctionsHook
from tests.providers.google.cloud.utils.base_gcp_mock import (
@@ -48,8 +50,8 @@ def test_gcf_client_creation(self, mock_build, mock_authorize):
mock_build.assert_called_once_with(
'cloudfunctions', 'v1', http=mock_authorize.return_value, cache_discovery=False
)
- self.assertEqual(mock_build.return_value, result)
- self.assertEqual(self.gcf_function_hook_no_project_id._conn, result)
+ assert mock_build.return_value == result
+ assert self.gcf_function_hook_no_project_id._conn == result
@mock.patch('airflow.providers.google.cloud.hooks.functions.CloudFunctionsHook.get_conn')
@mock.patch(
@@ -65,7 +67,7 @@ def test_create_new_function_overridden_project_id(self, wait_for_operation_to_c
res = self.gcf_function_hook_no_project_id.create_new_function(
project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, location=GCF_LOCATION, body={}
)
- self.assertIsNone(res)
+ assert res is None
create_method.assert_called_once_with(body={}, location='projects/example-project/locations/location')
execute_method.assert_called_once_with(num_retries=5)
wait_for_operation_to_complete.assert_called_once_with(operation_name='operation_id')
@@ -85,7 +87,7 @@ def test_upload_function_zip_overridden_project_id(self, get_conn, requests_put)
res = self.gcf_function_hook_no_project_id.upload_function_zip(
project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, location=GCF_LOCATION, zip_path="/tmp/path.zip"
)
- self.assertEqual("http://uploadHere", res)
+ assert "http://uploadHere" == res
generate_upload_url_method.assert_called_once_with(
parent='projects/example-project/locations/location'
)
@@ -112,8 +114,8 @@ def test_gcf_client_creation(self, mock_build, mock_authorize):
mock_build.assert_called_once_with(
'cloudfunctions', 'v1', http=mock_authorize.return_value, cache_discovery=False
)
- self.assertEqual(mock_build.return_value, result)
- self.assertEqual(self.gcf_function_hook._conn, result)
+ assert mock_build.return_value == result
+ assert self.gcf_function_hook._conn == result
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -136,7 +138,7 @@ def test_create_new_function(self, wait_for_operation_to_complete, get_conn, moc
body={},
project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
)
- self.assertIsNone(res)
+ assert res is None
create_method.assert_called_once_with(body={}, location='projects/example-project/locations/location')
execute_method.assert_called_once_with(num_retries=5)
wait_for_operation_to_complete.assert_called_once_with(operation_name='operation_id')
@@ -155,7 +157,7 @@ def test_create_new_function_override_project_id(self, wait_for_operation_to_com
res = self.gcf_function_hook.create_new_function(
project_id='new-project', location=GCF_LOCATION, body={}
)
- self.assertIsNone(res)
+ assert res is None
create_method.assert_called_once_with(body={}, location='projects/new-project/locations/location')
execute_method.assert_called_once_with(num_retries=5)
wait_for_operation_to_complete.assert_called_once_with(operation_name='operation_id')
@@ -168,8 +170,8 @@ def test_get_function(self, get_conn):
execute_method = get_method.return_value.execute
execute_method.return_value = {"name": "function"}
res = self.gcf_function_hook.get_function(name=GCF_FUNCTION)
- self.assertIsNotNone(res)
- self.assertEqual('function', res['name'])
+ assert res is not None
+ assert 'function' == res['name']
get_method.assert_called_once_with(name='function')
execute_method.assert_called_once_with(num_retries=5)
@@ -187,7 +189,7 @@ def test_delete_function(self, wait_for_operation_to_complete, get_conn):
res = self.gcf_function_hook.delete_function( # pylint: disable=assignment-from-no-return
name=GCF_FUNCTION
)
- self.assertIsNone(res)
+ assert res is None
delete_method.assert_called_once_with(name='function')
execute_method.assert_called_once_with(num_retries=5)
@@ -205,7 +207,7 @@ def test_update_function(self, wait_for_operation_to_complete, get_conn):
res = self.gcf_function_hook.update_function( # pylint: disable=assignment-from-no-return
update_mask=['a', 'b', 'c'], name=GCF_FUNCTION, body={}
)
- self.assertIsNone(res)
+ assert res is None
patch_method.assert_called_once_with(body={}, name='function', updateMask='a,b,c')
execute_method.assert_called_once_with(num_retries=5)
wait_for_operation_to_complete.assert_called_once_with(operation_name='operation_id')
@@ -232,7 +234,7 @@ def test_upload_function_zip(self, get_conn, requests_put, mock_project_id):
zip_path="/tmp/path.zip",
project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
)
- self.assertEqual("http://uploadHere", res)
+ assert "http://uploadHere" == res
generate_upload_url_method.assert_called_once_with(
parent='projects/example-project/locations/location'
)
@@ -258,7 +260,7 @@ def test_upload_function_zip_overridden_project_id(self, get_conn, requests_put)
res = self.gcf_function_hook.upload_function_zip(
project_id='new-project', location=GCF_LOCATION, zip_path="/tmp/path.zip"
)
- self.assertEqual("http://uploadHere", res)
+ assert "http://uploadHere" == res
generate_upload_url_method.assert_called_once_with(
parent='projects/new-project/locations/location'
)
@@ -292,7 +294,7 @@ def test_call_function(self, mock_get_conn):
)
call.assert_called_once_with(body=input_data, name=name)
- self.assertDictEqual(result, payload)
+ assert result == payload
@mock.patch('airflow.providers.google.cloud.hooks.functions.CloudFunctionsHook.get_conn')
def test_call_function_error(self, mock_get_conn):
@@ -305,7 +307,7 @@ def test_call_function_error(self, mock_get_conn):
function_id = "function1234"
input_data = {'key': 'value'}
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.gcf_function_hook.call_function(
function_id=function_id,
location=GCF_LOCATION,
diff --git a/tests/providers/google/cloud/hooks/test_gcs.py b/tests/providers/google/cloud/hooks/test_gcs.py
index 1ce44bb1ac9cf..5fe96ed3bcfd1 100644
--- a/tests/providers/google/cloud/hooks/test_gcs.py
+++ b/tests/providers/google/cloud/hooks/test_gcs.py
@@ -26,6 +26,7 @@
from unittest import mock
import dateutil
+import pytest
from google.cloud import exceptions, storage
from airflow.exceptions import AirflowException
@@ -48,17 +49,19 @@ def test_parse_gcs_url(self):
Test GCS url parsing
"""
- self.assertEqual(gcs._parse_gcs_url('gs://bucket/path/to/blob'), ('bucket', 'path/to/blob'))
+ assert gcs._parse_gcs_url('gs://bucket/path/to/blob') == ('bucket', 'path/to/blob')
# invalid URI
- self.assertRaises(AirflowException, gcs._parse_gcs_url, 'gs:/bucket/path/to/blob')
- self.assertRaises(AirflowException, gcs._parse_gcs_url, 'http://google.com/aaa')
+ with pytest.raises(AirflowException):
+ gcs._parse_gcs_url('gs:/bucket/path/to/blob')
+ with pytest.raises(AirflowException):
+ gcs._parse_gcs_url('http://google.com/aaa')
# trailing slash
- self.assertEqual(gcs._parse_gcs_url('gs://bucket/path/to/blob/'), ('bucket', 'path/to/blob/'))
+ assert gcs._parse_gcs_url('gs://bucket/path/to/blob/') == ('bucket', 'path/to/blob/')
# bucket only
- self.assertEqual(gcs._parse_gcs_url('gs://bucket/'), ('bucket', ''))
+ assert gcs._parse_gcs_url('gs://bucket/') == ('bucket', '')
class TestFallbackObjectUrlToObjectNameAndBucketName(unittest.TestCase):
@@ -83,9 +86,9 @@ def test_should_support_bucket_and_object(self):
self.assertion_on_body.assert_called_once()
def test_should_raise_exception_on_missing(self):
- with self.assertRaisesRegex(
+ with pytest.raises(
TypeError,
- re.escape(
+ match=re.escape(
"test_method() missing 2 required positional arguments: 'bucket_name' and 'object_name'"
),
):
@@ -93,7 +96,7 @@ def test_should_raise_exception_on_missing(self):
self.assertion_on_body.assert_not_called()
def test_should_raise_exception_on_mutually_exclusive(self):
- with self.assertRaisesRegex(AirflowException, re.escape("The mutually exclusive parameters.")):
+ with pytest.raises(AirflowException, match=re.escape("The mutually exclusive parameters.")):
self.test_method(
None,
bucket_name="BUCKET_NAME",
@@ -131,7 +134,7 @@ def test_storage_client_creation(
mock_client.assert_called_once_with(
client_info="CLIENT_INFO", credentials="CREDENTIALS", project="PROJECT_ID"
)
- self.assertEqual(mock_client.return_value, result)
+ assert mock_client.return_value == result
@mock.patch(GCS_STRING.format('GCSHook.get_conn'))
def test_exists(self, mock_service):
@@ -148,7 +151,7 @@ def test_exists(self, mock_service):
response = self.gcs_hook.exists(bucket_name=test_bucket, object_name=test_object)
# Then
- self.assertTrue(response)
+ assert response
bucket_mock.assert_called_once_with(test_bucket)
blob_object.assert_called_once_with(blob_name=test_object)
exists_method.assert_called_once_with()
@@ -168,7 +171,7 @@ def test_exists_nonexisting_object(self, mock_service):
response = self.gcs_hook.exists(bucket_name=test_bucket, object_name=test_object)
# Then
- self.assertFalse(response)
+ assert not response
@mock.patch(GCS_STRING.format('GCSHook.get_conn'))
def test_is_updated_after(self, mock_service):
@@ -186,7 +189,7 @@ def test_is_updated_after(self, mock_service):
)
# Then
- self.assertTrue(response)
+ assert response
@mock.patch(GCS_STRING.format('GCSHook.get_conn'))
def test_is_updated_before(self, mock_service):
@@ -204,7 +207,7 @@ def test_is_updated_before(self, mock_service):
)
# Then
- self.assertTrue(response)
+ assert response
@mock.patch(GCS_STRING.format('GCSHook.get_conn'))
def test_is_updated_between(self, mock_service):
@@ -225,7 +228,7 @@ def test_is_updated_between(self, mock_service):
)
# Then
- self.assertTrue(response)
+ assert response
@mock.patch(GCS_STRING.format('GCSHook.get_conn'))
def test_is_older_than_with_true_cond(self, mock_service):
@@ -243,7 +246,7 @@ def test_is_older_than_with_true_cond(self, mock_service):
)
# Then
- self.assertTrue(response)
+ assert response
@mock.patch(GCS_STRING.format('GCSHook.get_conn'))
def test_is_older_than_with_false_cond(self, mock_service):
@@ -260,7 +263,7 @@ def test_is_older_than_with_false_cond(self, mock_service):
bucket_name=test_bucket, object_name=test_object, seconds=86400 # 24hr
)
# Then
- self.assertFalse(response)
+ assert not response
@mock.patch('google.cloud.storage.Bucket')
@mock.patch(GCS_STRING.format('GCSHook.get_conn'))
@@ -289,7 +292,7 @@ def test_copy(self, mock_service, mock_bucket):
)
# Then
- self.assertEqual(response, None)
+ assert response is None
copy_method.assert_called_once_with(
blob=source_blob, destination_bucket=destination_bucket_instance, new_name=destination_object
)
@@ -300,7 +303,7 @@ def test_copy_fail_same_source_and_destination(self):
destination_bucket = 'test-source-bucket'
destination_object = 'test-source-object'
- with self.assertRaises(ValueError) as e:
+ with pytest.raises(ValueError) as ctx:
self.gcs_hook.copy(
source_bucket=source_bucket,
source_object=source_object,
@@ -308,10 +311,12 @@ def test_copy_fail_same_source_and_destination(self):
destination_object=destination_object,
)
- self.assertEqual(
- str(e.exception),
- 'Either source/destination bucket or source/destination object '
- 'must be different, not both the same: bucket=%s, object=%s' % (source_bucket, source_object),
+ assert str(ctx.value) == (
+ 'Either source/destination bucket or source/destination object must be different, '
+ 'not both the same: bucket={}, object={}'
+ ).format(
+ source_bucket,
+ source_object,
)
def test_copy_empty_source_bucket(self):
@@ -320,7 +325,7 @@ def test_copy_empty_source_bucket(self):
destination_bucket = 'test-dest-bucket'
destination_object = 'test-dest-object'
- with self.assertRaises(ValueError) as e:
+ with pytest.raises(ValueError) as ctx:
self.gcs_hook.copy(
source_bucket=source_bucket,
source_object=source_object,
@@ -328,7 +333,7 @@ def test_copy_empty_source_bucket(self):
destination_object=destination_object,
)
- self.assertEqual(str(e.exception), 'source_bucket and source_object cannot be empty.')
+ assert str(ctx.value) == 'source_bucket and source_object cannot be empty.'
def test_copy_empty_source_object(self):
source_bucket = 'test-source-object'
@@ -336,7 +341,7 @@ def test_copy_empty_source_object(self):
destination_bucket = 'test-dest-bucket'
destination_object = 'test-dest-object'
- with self.assertRaises(ValueError) as e:
+ with pytest.raises(ValueError) as ctx:
self.gcs_hook.copy(
source_bucket=source_bucket,
source_object=source_object,
@@ -344,7 +349,7 @@ def test_copy_empty_source_object(self):
destination_object=destination_object,
)
- self.assertEqual(str(e.exception), 'source_bucket and source_object cannot be empty.')
+ assert str(ctx.value) == 'source_bucket and source_object cannot be empty.'
@mock.patch('google.cloud.storage.Bucket')
@mock.patch(GCS_STRING.format('GCSHook.get_conn'))
@@ -372,7 +377,7 @@ def test_rewrite(self, mock_service, mock_bucket):
)
# Then
- self.assertEqual(response, None)
+ assert response is None
rewrite_method.assert_called_once_with(source=source_blob)
def test_rewrite_empty_source_bucket(self):
@@ -381,7 +386,7 @@ def test_rewrite_empty_source_bucket(self):
destination_bucket = 'test-dest-bucket'
destination_object = 'test-dest-object'
- with self.assertRaises(ValueError) as e:
+ with pytest.raises(ValueError) as ctx:
self.gcs_hook.rewrite(
source_bucket=source_bucket,
source_object=source_object,
@@ -389,7 +394,7 @@ def test_rewrite_empty_source_bucket(self):
destination_object=destination_object,
)
- self.assertEqual(str(e.exception), 'source_bucket and source_object cannot be empty.')
+ assert str(ctx.value) == 'source_bucket and source_object cannot be empty.'
def test_rewrite_empty_source_object(self):
source_bucket = 'test-source-object'
@@ -397,7 +402,7 @@ def test_rewrite_empty_source_object(self):
destination_bucket = 'test-dest-bucket'
destination_object = 'test-dest-object'
- with self.assertRaises(ValueError) as e:
+ with pytest.raises(ValueError) as ctx:
self.gcs_hook.rewrite(
source_bucket=source_bucket,
source_object=source_object,
@@ -405,7 +410,7 @@ def test_rewrite_empty_source_object(self):
destination_object=destination_object,
)
- self.assertEqual(str(e.exception), 'source_bucket and source_object cannot be empty.')
+ assert str(ctx.value) == 'source_bucket and source_object cannot be empty.'
@mock.patch('google.cloud.storage.Bucket')
@mock.patch(GCS_STRING.format('GCSHook.get_conn'))
@@ -422,7 +427,7 @@ def test_delete(self, mock_service, mock_bucket):
response = self.gcs_hook.delete( # pylint: disable=assignment-from-no-return
bucket_name=test_bucket, object_name=test_object
)
- self.assertIsNone(response)
+ assert response is None
@mock.patch(GCS_STRING.format('GCSHook.get_conn'))
def test_delete_nonexisting_object(self, mock_service):
@@ -434,7 +439,7 @@ def test_delete_nonexisting_object(self, mock_service):
delete_method = blob.return_value.delete
delete_method.side_effect = exceptions.NotFound(message="Not Found")
- with self.assertRaises(exceptions.NotFound):
+ with pytest.raises(exceptions.NotFound):
self.gcs_hook.delete(bucket_name=test_bucket, object_name=test_object)
@mock.patch(GCS_STRING.format('GCSHook.get_conn'))
@@ -470,7 +475,7 @@ def test_object_get_size(self, mock_service):
response = self.gcs_hook.get_size(bucket_name=test_bucket, object_name=test_object)
- self.assertEqual(response, returned_file_size)
+ assert response == returned_file_size
@mock.patch(GCS_STRING.format('GCSHook.get_conn'))
def test_object_get_crc32c(self, mock_service):
@@ -484,7 +489,7 @@ def test_object_get_crc32c(self, mock_service):
response = self.gcs_hook.get_crc32c(bucket_name=test_bucket, object_name=test_object)
- self.assertEqual(response, returned_file_crc32c)
+ assert response == returned_file_crc32c
@mock.patch(GCS_STRING.format('GCSHook.get_conn'))
def test_object_get_md5hash(self, mock_service):
@@ -498,7 +503,7 @@ def test_object_get_md5hash(self, mock_service):
response = self.gcs_hook.get_md5hash(bucket_name=test_bucket, object_name=test_object)
- self.assertEqual(response, returned_file_md5hash)
+ assert response == returned_file_md5hash
@mock.patch('google.cloud.storage.Bucket')
@mock.patch(GCS_STRING.format('GCSHook.get_conn'))
@@ -526,10 +531,10 @@ def test_create_bucket(self, mock_service, mock_bucket):
project_id=test_project,
)
- self.assertEqual(response, sample_bucket.id)
+ assert response == sample_bucket.id
- self.assertEqual(sample_bucket.storage_class, test_storage_class)
- self.assertDictEqual(sample_bucket.labels, test_labels)
+ assert sample_bucket.storage_class == test_storage_class
+ assert sample_bucket.labels == test_labels
mock_service.return_value.bucket.return_value.create.assert_called_once_with(
project=test_project, location=test_location
@@ -562,7 +567,7 @@ def test_create_bucket_with_resource(self, mock_service, mock_bucket):
labels=test_labels,
project_id=test_project,
)
- self.assertEqual(response, sample_bucket.id)
+ assert response == sample_bucket.id
mock_service.return_value.bucket.return_value._patch_property.assert_called_once_with(
name='versioning', value=test_versioning_enabled
@@ -598,14 +603,14 @@ def test_compose_with_empty_source_objects(self, mock_service): # pylint: disab
test_source_objects = []
test_destination_object = 'test_object_composed'
- with self.assertRaises(ValueError) as e:
+ with pytest.raises(ValueError) as ctx:
self.gcs_hook.compose(
bucket_name=test_bucket,
source_objects=test_source_objects,
destination_object=test_destination_object,
)
- self.assertEqual(str(e.exception), 'source_objects cannot be empty.')
+ assert str(ctx.value) == 'source_objects cannot be empty.'
@mock.patch(GCS_STRING.format('GCSHook.get_conn'))
def test_compose_without_bucket(self, mock_service): # pylint: disable=unused-argument
@@ -613,14 +618,14 @@ def test_compose_without_bucket(self, mock_service): # pylint: disable=unused-a
test_source_objects = ['test_object_1', 'test_object_2', 'test_object_3']
test_destination_object = 'test_object_composed'
- with self.assertRaises(ValueError) as e:
+ with pytest.raises(ValueError) as ctx:
self.gcs_hook.compose(
bucket_name=test_bucket,
source_objects=test_source_objects,
destination_object=test_destination_object,
)
- self.assertEqual(str(e.exception), 'bucket_name and destination_object cannot be empty.')
+ assert str(ctx.value) == 'bucket_name and destination_object cannot be empty.'
@mock.patch(GCS_STRING.format('GCSHook.get_conn'))
def test_compose_without_destination_object(self, mock_service): # pylint: disable=unused-argument
@@ -628,14 +633,14 @@ def test_compose_without_destination_object(self, mock_service): # pylint: disa
test_source_objects = ['test_object_1', 'test_object_2', 'test_object_3']
test_destination_object = None
- with self.assertRaises(ValueError) as e:
+ with pytest.raises(ValueError) as ctx:
self.gcs_hook.compose(
bucket_name=test_bucket,
source_objects=test_source_objects,
destination_object=test_destination_object,
)
- self.assertEqual(str(e.exception), 'bucket_name and destination_object cannot be empty.')
+ assert str(ctx.value) == 'bucket_name and destination_object cannot be empty.'
@mock.patch(GCS_STRING.format('GCSHook.get_conn'))
def test_download_as_string(self, mock_service):
@@ -648,7 +653,7 @@ def test_download_as_string(self, mock_service):
response = self.gcs_hook.download(bucket_name=test_bucket, object_name=test_object, filename=None)
- self.assertEqual(response, test_object_bytes)
+ assert response == test_object_bytes
download_method.assert_called_once_with()
@mock.patch(GCS_STRING.format('GCSHook.get_conn'))
@@ -671,7 +676,7 @@ def test_download_to_file(self, mock_service):
bucket_name=test_bucket, object_name=test_object, filename=test_file
)
- self.assertEqual(response, test_file)
+ assert response == test_file
download_filename_method.assert_called_once_with(test_file, timeout=60)
@mock.patch(GCS_STRING.format('NamedTemporaryFile'))
@@ -696,7 +701,7 @@ def test_provide_file(self, mock_service, mock_temp_file):
with self.gcs_hook.provide_file(bucket_name=test_bucket, object_name=test_object) as response:
- self.assertEqual(test_file, response.name)
+ assert test_file == response.name
download_filename_method.assert_called_once_with(test_file, timeout=60)
mock_temp_file.assert_has_calls(
[
@@ -771,7 +776,7 @@ def test_upload_file_gzip(self, mock_service):
test_object = 'test_object'
self.gcs_hook.upload(test_bucket, test_object, filename=self.testfile.name, gzip=True)
- self.assertFalse(os.path.exists(self.testfile.name + '.gz'))
+ assert not os.path.exists(self.testfile.name + '.gz')
@mock.patch(GCS_STRING.format('GCSHook.get_conn'))
def test_upload_data_str(self, mock_service):
@@ -842,15 +847,15 @@ def test_upload_exceptions(self, mock_service):
)
no_params_excep = "'filename' and 'data' parameter missing. One is required to upload to gcs."
- with self.assertRaises(ValueError) as cm:
+ with pytest.raises(ValueError) as ctx:
self.gcs_hook.upload(test_bucket, test_object)
- self.assertEqual(no_params_excep, str(cm.exception))
+ assert no_params_excep == str(ctx.value)
- with self.assertRaises(ValueError) as cm:
+ with pytest.raises(ValueError) as ctx:
self.gcs_hook.upload(
test_bucket, test_object, filename=self.testfile.name, data=self.testdata_str
)
- self.assertEqual(both_params_excep, str(cm.exception))
+ assert both_params_excep == str(ctx.value)
class TestSyncGcsHook(unittest.TestCase):
diff --git a/tests/providers/google/cloud/hooks/test_gdm.py b/tests/providers/google/cloud/hooks/test_gdm.py
index 5f525bf5fabde..23b7b280d912d 100644
--- a/tests/providers/google/cloud/hooks/test_gdm.py
+++ b/tests/providers/google/cloud/hooks/test_gdm.py
@@ -19,6 +19,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.gdm import GoogleDeploymentManagerHook
@@ -71,12 +73,12 @@ def test_list_deployments(self, mock_get_conn):
orderBy='name',
)
- self.assertEqual(mock_get_conn.return_value.deployments.return_value.list_next.call_count, 2)
+ assert mock_get_conn.return_value.deployments.return_value.list_next.call_count == 2
- self.assertEqual(
- deployments,
- [{'id': 'deployment1', 'name': 'test-deploy1'}, {'id': 'deployment2', 'name': 'test-deploy2'}],
- )
+ assert deployments == [
+ {'id': 'deployment1', 'name': 'test-deploy1'},
+ {'id': 'deployment2', 'name': 'test-deploy2'},
+ ]
@mock.patch("airflow.providers.google.cloud.hooks.gdm.GoogleDeploymentManagerHook.get_conn")
def test_delete_deployment(self, mock_get_conn):
@@ -93,7 +95,7 @@ def test_delete_deployment_delete_fails(self, mock_get_conn):
mock_get_conn.return_value.deployments.return_value.delete.return_value.execute.return_value = resp
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.gdm_hook.delete_deployment(project_id=TEST_PROJECT, deployment=TEST_DEPLOYMENT)
mock_get_conn.assert_called_once_with()
diff --git a/tests/providers/google/cloud/hooks/test_kms.py b/tests/providers/google/cloud/hooks/test_kms.py
index d2f4519df5534..4de1dfbc6e91b 100644
--- a/tests/providers/google/cloud/hooks/test_kms.py
+++ b/tests/providers/google/cloud/hooks/test_kms.py
@@ -73,8 +73,8 @@ def test_kms_client_creation(self, mock_client, mock_get_creds, mock_client_info
credentials=mock_get_creds.return_value,
client_info=mock_client_info.return_value,
)
- self.assertEqual(mock_client.return_value, result)
- self.assertEqual(self.kms_hook._conn, result)
+ assert mock_client.return_value == result
+ assert self.kms_hook._conn == result
@mock.patch("airflow.providers.google.cloud.hooks.kms.CloudKMSHook.get_conn")
def test_encrypt(self, mock_get_conn):
@@ -91,7 +91,7 @@ def test_encrypt(self, mock_get_conn):
timeout=None,
metadata=(),
)
- self.assertEqual(PLAINTEXT_b64, result)
+ assert PLAINTEXT_b64 == result
@mock.patch("airflow.providers.google.cloud.hooks.kms.CloudKMSHook.get_conn")
def test_encrypt_with_auth_data(self, mock_get_conn):
@@ -108,7 +108,7 @@ def test_encrypt_with_auth_data(self, mock_get_conn):
timeout=None,
metadata=(),
)
- self.assertEqual(PLAINTEXT_b64, result)
+ assert PLAINTEXT_b64 == result
@mock.patch("airflow.providers.google.cloud.hooks.kms.CloudKMSHook.get_conn")
def test_decrypt(self, mock_get_conn):
@@ -125,7 +125,7 @@ def test_decrypt(self, mock_get_conn):
timeout=None,
metadata=(),
)
- self.assertEqual(PLAINTEXT, result)
+ assert PLAINTEXT == result
@mock.patch("airflow.providers.google.cloud.hooks.kms.CloudKMSHook.get_conn")
def test_decrypt_with_auth_data(self, mock_get_conn):
@@ -142,4 +142,4 @@ def test_decrypt_with_auth_data(self, mock_get_conn):
timeout=None,
metadata=(),
)
- self.assertEqual(PLAINTEXT, result)
+ assert PLAINTEXT == result
diff --git a/tests/providers/google/cloud/hooks/test_kms_system.py b/tests/providers/google/cloud/hooks/test_kms_system.py
index 5c77ed859f928..6963430a90b84 100644
--- a/tests/providers/google/cloud/hooks/test_kms_system.py
+++ b/tests/providers/google/cloud/hooks/test_kms_system.py
@@ -66,7 +66,7 @@ def test_encrypt(self):
)
with open(f"{tmp_dir}/mysecret.txt", "rb") as secret_file:
secret = secret_file.read()
- self.assertEqual(secret, b"TEST-SECRET")
+ assert secret == b"TEST-SECRET"
@provide_gcp_context(GCP_KMS_KEY)
def test_decrypt(self):
@@ -101,4 +101,4 @@ def test_decrypt(self):
),
ciphertext=encrypted_secret,
)
- self.assertEqual(content, b"TEST-SECRET")
+ assert content == b"TEST-SECRET"
diff --git a/tests/providers/google/cloud/hooks/test_kubernetes_engine.py b/tests/providers/google/cloud/hooks/test_kubernetes_engine.py
index 1040ec4cefb11..efbe346279fad 100644
--- a/tests/providers/google/cloud/hooks/test_kubernetes_engine.py
+++ b/tests/providers/google/cloud/hooks/test_kubernetes_engine.py
@@ -20,6 +20,7 @@
from unittest import mock
from unittest.mock import PropertyMock
+import pytest
from google.cloud.container_v1.types import Cluster
from airflow.exceptions import AirflowException
@@ -47,8 +48,8 @@ def test_gke_cluster_client_creation(self, mock_client, mock_get_creds, mock_cli
mock_client.assert_called_once_with(
credentials=mock_get_creds.return_value, client_info=mock_client_info.return_value
)
- self.assertEqual(mock_client.return_value, result)
- self.assertEqual(self.gke_hook._client, result)
+ assert mock_client.return_value == result
+ assert self.gke_hook._client == result
class TestGKEHookDelete(unittest.TestCase):
@@ -113,7 +114,7 @@ def test_delete_cluster_error(self, wait_mock, convert_mock, mock_project_id):
# To force an error
self.gke_hook._client.delete_cluster.side_effect = AirflowException('400')
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.gke_hook.delete_cluster(name='a-cluster') # pylint: disable=no-value-for-parameter
wait_mock.assert_not_called()
convert_mock.assert_not_called()
@@ -187,7 +188,7 @@ def test_create_cluster_error(self, wait_mock, convert_mock):
# to force an error
mock_cluster_proto = None
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.gke_hook.create_cluster(mock_cluster_proto) # pylint: disable=no-value-for-parameter
wait_mock.assert_not_called()
convert_mock.assert_not_called()
@@ -208,7 +209,7 @@ def test_create_cluster_already_exists(self, wait_mock, convert_mock, log_mock,
self.gke_hook.create_cluster(cluster={}, project_id=TEST_GCP_PROJECT_ID)
wait_mock.assert_not_called()
- self.assertEqual(convert_mock.call_count, 1)
+ assert convert_mock.call_count == 1
log_mock.info.assert_any_call("Assuming Success: %s", message)
@@ -279,7 +280,7 @@ def test_wait_for_response_done(self, time_mock):
mock_op = mock.Mock()
mock_op.status = Operation.Status.DONE
self.gke_hook.wait_for_operation(mock_op)
- self.assertEqual(time_mock.call_count, 1)
+ assert time_mock.call_count == 1
@mock.patch("airflow.providers.google.cloud.hooks.kubernetes_engine.time.sleep")
def test_wait_for_response_exception(self, time_mock):
@@ -289,9 +290,9 @@ def test_wait_for_response_exception(self, time_mock):
mock_op = mock.Mock()
mock_op.status = Operation.Status.ABORTING
- with self.assertRaises(GoogleCloudError):
+ with pytest.raises(GoogleCloudError):
self.gke_hook.wait_for_operation(mock_op)
- self.assertEqual(time_mock.call_count, 1)
+ assert time_mock.call_count == 1
@mock.patch("airflow.providers.google.cloud.hooks.kubernetes_engine.GKEHook.get_operation")
@mock.patch("airflow.providers.google.cloud.hooks.kubernetes_engine.time.sleep")
@@ -307,7 +308,7 @@ def test_wait_for_response_running(self, time_mock, operation_mock):
operation_mock.side_effect = [pending_op, done_op]
self.gke_hook.wait_for_operation(running_op, project_id=TEST_GCP_PROJECT_ID)
- self.assertEqual(time_mock.call_count, 3)
+ assert time_mock.call_count == 3
operation_mock.assert_any_call(running_op.name, project_id=TEST_GCP_PROJECT_ID)
operation_mock.assert_any_call(pending_op.name, project_id=TEST_GCP_PROJECT_ID)
- self.assertEqual(operation_mock.call_count, 2)
+ assert operation_mock.call_count == 2
diff --git a/tests/providers/google/cloud/hooks/test_life_sciences.py b/tests/providers/google/cloud/hooks/test_life_sciences.py
index e203e87f3eb79..a071802333403 100644
--- a/tests/providers/google/cloud/hooks/test_life_sciences.py
+++ b/tests/providers/google/cloud/hooks/test_life_sciences.py
@@ -22,6 +22,8 @@
from unittest import mock
from unittest.mock import PropertyMock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.life_sciences import LifeSciencesHook
from tests.providers.google.cloud.utils.base_gcp_mock import (
@@ -55,7 +57,7 @@ def setUp(self):
def test_location_path(self):
path = 'projects/life-science-project-id/locations/test-location'
path2 = self.hook._location_path(project_id=TEST_PROJECT_ID, location=TEST_LOCATION)
- self.assertEqual(path, path2)
+ assert path == path2
@mock.patch("airflow.providers.google.cloud.hooks.life_sciences.LifeSciencesHook._authorize")
@mock.patch("airflow.providers.google.cloud.hooks.life_sciences.build")
@@ -64,8 +66,8 @@ def test_life_science_client_creation(self, mock_build, mock_authorize):
mock_build.assert_called_once_with(
'lifesciences', 'v2beta', http=mock_authorize.return_value, cache_discovery=False
)
- self.assertEqual(mock_build.return_value, result)
- self.assertEqual(self.hook._conn, result)
+ assert mock_build.return_value == result
+ assert self.hook._conn == result
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -97,7 +99,7 @@ def test_run_pipeline_immediately_complete(self, get_conn_mock, mock_project_id)
.assert_called_once_with(body={},
parent=parent)
# fmt: on
- self.assertEqual(result, TEST_OPERATION)
+ assert result == TEST_OPERATION
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -126,7 +128,7 @@ def test_waiting_operation(self, _, get_conn_mock, mock_project_id):
# fmt: on
result = self.hook.run_pipeline(body={}, location=TEST_LOCATION, project_id=TEST_PROJECT_ID)
- self.assertEqual(result, TEST_OPERATION)
+ assert result == TEST_OPERATION
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -151,7 +153,7 @@ def test_error_operation(self, _, get_conn_mock, mock_project_id):
.get.return_value \
.execute = execute_mock
# fmt: on
- with self.assertRaisesRegex(AirflowException, "error"):
+ with pytest.raises(AirflowException, match="error"):
self.hook.run_pipeline(body={}, location=TEST_LOCATION, project_id=TEST_PROJECT_ID)
@@ -170,8 +172,8 @@ def test_life_science_client_creation(self, mock_build, mock_authorize):
mock_build.assert_called_once_with(
'lifesciences', 'v2beta', http=mock_authorize.return_value, cache_discovery=False
)
- self.assertEqual(mock_build.return_value, result)
- self.assertEqual(self.hook._conn, result)
+ assert mock_build.return_value == result
+ assert self.hook._conn == result
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -202,7 +204,7 @@ def test_run_pipeline_immediately_complete(self, get_conn_mock, mock_project_id)
.assert_called_once_with(body={},
parent=parent)
# fmt: on
- self.assertEqual(result, TEST_OPERATION)
+ assert result == TEST_OPERATION
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -231,7 +233,7 @@ def test_waiting_operation(self, _, get_conn_mock, mock_project_id):
# pylint: disable=no-value-for-parameter
result = self.hook.run_pipeline(body={}, location=TEST_LOCATION)
- self.assertEqual(result, TEST_OPERATION)
+ assert result == TEST_OPERATION
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -258,7 +260,7 @@ def test_error_operation(self, _, get_conn_mock, mock_project_id):
.execute = execute_mock
# fmt: on
- with self.assertRaisesRegex(AirflowException, "error"):
+ with pytest.raises(AirflowException, match="error"):
self.hook.run_pipeline(body={}, location=TEST_LOCATION) # pylint: disable=no-value-for-parameter
@@ -277,8 +279,8 @@ def test_life_science_client_creation(self, mock_build, mock_authorize):
mock_build.assert_called_once_with(
'lifesciences', 'v2beta', http=mock_authorize.return_value, cache_discovery=False
)
- self.assertEqual(mock_build.return_value, result)
- self.assertEqual(self.hook._conn, result)
+ assert mock_build.return_value == result
+ assert self.hook._conn == result
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -287,11 +289,10 @@ def test_life_science_client_creation(self, mock_build, mock_authorize):
)
@mock.patch("airflow.providers.google.cloud.hooks.life_sciences.LifeSciencesHook.get_conn")
def test_run_pipeline(self, get_conn_mock, mock_project_id): # pylint: disable=unused-argument
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
self.hook.run_pipeline(body={}, location=TEST_LOCATION) # pylint: disable=no-value-for-parameter
- self.assertEqual(
+ assert (
"The project id must be passed either as keyword project_id parameter or as project_id extra in "
- "Google Cloud connection definition. Both are not set!",
- str(e.exception),
+ "Google Cloud connection definition. Both are not set!" == str(ctx.value)
)
diff --git a/tests/providers/google/cloud/hooks/test_mlengine.py b/tests/providers/google/cloud/hooks/test_mlengine.py
index f8381e829ff97..a50b05aa43a8c 100644
--- a/tests/providers/google/cloud/hooks/test_mlengine.py
+++ b/tests/providers/google/cloud/hooks/test_mlengine.py
@@ -21,6 +21,7 @@
from unittest.mock import PropertyMock
import httplib2
+import pytest
from googleapiclient.errors import HttpError
from airflow.providers.google.cloud.hooks import mlengine as hook
@@ -40,7 +41,7 @@ def setUp(self) -> None:
def test_mle_engine_client_creation(self, mock_build, mock_authorize):
result = self.hook.get_conn()
- self.assertEqual(mock_build.return_value, result)
+ assert mock_build.return_value == result
mock_build.assert_called_with('ml', 'v1', http=mock_authorize.return_value, cache_discovery=False)
@mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn")
@@ -77,7 +78,7 @@ def test_create_version(self, mock_get_conn):
project_id=project_id, model_name=model_name, version_spec=deepcopy(version)
)
- self.assertEqual(create_version_response, operation_done)
+ assert create_version_response == operation_done
mock_get_conn.assert_has_calls(
[
@@ -129,7 +130,7 @@ def test_create_version_with_labels(self, mock_get_conn):
)
# fmt: on
- self.assertEqual(create_version_response, operation_done)
+ assert create_version_response == operation_done
mock_get_conn.assert_has_calls(
[
@@ -166,7 +167,7 @@ def test_set_default_version(self, mock_get_conn):
project_id=project_id, model_name=model_name, version_name=version_name
)
- self.assertEqual(set_default_version_response, operation_done)
+ assert set_default_version_response == operation_done
mock_get_conn.assert_has_calls(
[
@@ -203,7 +204,7 @@ def test_list_versions(self, mock_get_conn, mock_sleep):
list_versions_response = self.hook.list_versions(
project_id=project_id, model_name=model_name)
# fmt: on
- self.assertEqual(list_versions_response, version_names)
+ assert list_versions_response == version_names
mock_get_conn.assert_has_calls(
[
mock.call().projects().models().versions().list(pageSize=100, parent=model_path),
@@ -252,7 +253,7 @@ def test_delete_version(self, mock_get_conn):
project_id=project_id, model_name=model_name, version_name=version_name
)
- self.assertEqual(delete_version_response, operation_done)
+ assert delete_version_response == operation_done
mock_get_conn.assert_has_calls(
[
mock.call().projects().models().versions().delete(name=version_path),
@@ -286,7 +287,7 @@ def test_create_model(self, mock_get_conn):
# fmt: on
create_model_response = self.hook.create_model(project_id=project_id, model=deepcopy(model))
- self.assertEqual(create_model_response, model)
+ assert create_model_response == model
mock_get_conn.assert_has_calls(
[
mock.call().projects().models().create(body=model_with_airflow_version, parent=project_path),
@@ -349,7 +350,7 @@ def test_create_model_idempotency(self, mock_get_conn):
# fmt: on
create_model_response = self.hook.create_model(project_id=project_id, model=deepcopy(model))
- self.assertEqual(create_model_response, model)
+ assert create_model_response == model
mock_get_conn.assert_has_calls(
[
mock.call().projects().models().create(body=model_with_airflow_version, parent=project_path),
@@ -386,7 +387,7 @@ def test_create_model_with_labels(self, mock_get_conn):
project_id=project_id, model=deepcopy(model)
)
# fmt: on
- self.assertEqual(create_model_response, model)
+ assert create_model_response == model
mock_get_conn.assert_has_calls(
[
mock.call().projects().models().create(body=model_with_airflow_version, parent=project_path),
@@ -411,7 +412,7 @@ def test_get_model(self, mock_get_conn):
# fmt: on
get_model_response = self.hook.get_model(project_id=project_id, model_name=model_name)
- self.assertEqual(get_model_response, model)
+ assert get_model_response == model
mock_get_conn.assert_has_calls(
[
mock.call().projects().models().get(name=model_path),
@@ -574,7 +575,7 @@ def test_create_mlengine_job(self, mock_get_conn, mock_sleep):
# fmt: on
create_job_response = self.hook.create_job(project_id=project_id, job=deepcopy(new_job))
- self.assertEqual(create_job_response, job_succeeded)
+ assert create_job_response == job_succeeded
mock_get_conn.assert_has_calls(
[
mock.call().projects().jobs().create(body=new_job_with_airflow_version, parent=project_path),
@@ -626,7 +627,7 @@ def test_create_mlengine_job_with_labels(self, mock_get_conn, mock_sleep):
project_id=project_id, job=deepcopy(new_job)
)
# fmt: on
- self.assertEqual(create_job_response, job_succeeded)
+ assert create_job_response == job_succeeded
mock_get_conn.assert_has_calls(
[
mock.call().projects().jobs().create(body=new_job_with_airflow_version, parent=project_path),
@@ -666,7 +667,7 @@ def test_create_mlengine_job_reuse_existing_job_by_default(self, mock_get_conn):
# fmt: on
create_job_response = self.hook.create_job(project_id=project_id, job=job_succeeded)
- self.assertEqual(create_job_response, job_succeeded)
+ assert create_job_response == job_succeeded
mock_get_conn.assert_has_calls(
[
mock.call().projects().jobs().create(body=job_succeeded, parent=project_path),
@@ -715,7 +716,7 @@ def test_create_mlengine_job_check_existing_job_failed(self, mock_get_conn):
def check_input(existing_job):
return existing_job.get('someInput') == my_job['someInput']
- with self.assertRaises(HttpError):
+ with pytest.raises(HttpError):
self.hook.create_job(project_id=project_id, job=my_job, use_existing_job_fn=check_input)
@mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn")
@@ -753,7 +754,7 @@ def check_input(existing_job):
project_id=project_id, job=my_job, use_existing_job_fn=check_input
)
- self.assertEqual(create_job_response, my_job)
+ assert create_job_response == my_job
@mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn")
def test_cancel_mlengine_job(self, mock_get_conn):
@@ -773,7 +774,7 @@ def test_cancel_mlengine_job(self, mock_get_conn):
# fmt: on
cancel_job_response = self.hook.cancel_job(job_id=job_id, project_id=project_id)
- self.assertEqual(cancel_job_response, job_cancelled)
+ assert cancel_job_response == job_cancelled
mock_get_conn.assert_has_calls(
[
mock.call().projects().jobs().cancel(name=job_path),
@@ -805,7 +806,7 @@ def test_cancel_mlengine_job_nonexistent_job(self, mock_get_conn):
execute.return_value
) = job_cancelled
# fmt: on
- with self.assertRaises(HttpError):
+ with pytest.raises(HttpError):
self.hook.cancel_job(job_id=job_id, project_id=project_id)
@mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn")
@@ -836,7 +837,7 @@ def test_cancel_mlengine_job_completed_job(self, mock_get_conn):
# fmt: on
cancel_job_response = self.hook.cancel_job(job_id=job_id, project_id=project_id)
- self.assertEqual(cancel_job_response, job_cancelled)
+ assert cancel_job_response == job_cancelled
mock_get_conn.assert_has_calls(
[
mock.call().projects().jobs().cancel(name=job_path),
@@ -889,7 +890,7 @@ def test_create_version(self, mock_get_conn, mock_project_id):
model_name=model_name, version_spec=version, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST
)
- self.assertEqual(create_version_response, operation_done)
+ assert create_version_response == operation_done
mock_get_conn.assert_has_calls(
[
mock.call().projects().models().versions().create(body=version, parent=model_path),
@@ -930,7 +931,7 @@ def test_set_default_version(self, mock_get_conn, mock_project_id):
project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
)
- self.assertEqual(set_default_version_response, operation_done)
+ assert set_default_version_response == operation_done
mock_get_conn.assert_has_calls(
[
mock.call().projects().models().versions().setDefault(body={}, name=version_path),
@@ -971,7 +972,7 @@ def test_list_versions(self, mock_get_conn, mock_sleep, mock_project_id):
model_name=model_name, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST
)
- self.assertEqual(list_versions_response, version_names)
+ assert list_versions_response == version_names
mock_get_conn.assert_has_calls(
[
mock.call().projects().models().versions().list(pageSize=100, parent=model_path),
@@ -1029,7 +1030,7 @@ def test_delete_version(self, mock_get_conn, mock_project_id):
model_name=model_name, version_name=version_name, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST
)
- self.assertEqual(delete_version_response, operation_done)
+ assert delete_version_response == operation_done
mock_get_conn.assert_has_calls(
[
mock.call().projects().models().versions().delete(name=version_path),
@@ -1063,7 +1064,7 @@ def test_create_model(self, mock_get_conn, mock_project_id):
# fmt: on
create_model_response = self.hook.create_model(model=model, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST)
- self.assertEqual(create_model_response, model)
+ assert create_model_response == model
mock_get_conn.assert_has_calls(
[
mock.call().projects().models().create(body=model, parent=project_path),
@@ -1094,7 +1095,7 @@ def test_get_model(self, mock_get_conn, mock_project_id):
model_name=model_name, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST
)
- self.assertEqual(get_model_response, model)
+ assert get_model_response == model
mock_get_conn.assert_has_calls(
[
mock.call().projects().models().get(name=model_path),
@@ -1171,7 +1172,7 @@ def test_create_mlengine_job(self, mock_get_conn, mock_sleep, mock_project_id):
# fmt: on
create_job_response = self.hook.create_job(job=new_job, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST)
- self.assertEqual(create_job_response, job_succeeded)
+ assert create_job_response == job_succeeded
mock_get_conn.assert_has_calls(
[
mock.call().projects().jobs().create(body=new_job, parent=project_path),
@@ -1203,7 +1204,7 @@ def test_cancel_mlengine_job(self, mock_get_conn, mock_project_id):
# fmt: on
cancel_job_response = self.hook.cancel_job(job_id=job_id, project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST)
- self.assertEqual(cancel_job_response, job_cancelled)
+ assert cancel_job_response == job_cancelled
mock_get_conn.assert_has_calls(
[
mock.call().projects().jobs().cancel(name=job_path),
diff --git a/tests/providers/google/cloud/hooks/test_natural_language.py b/tests/providers/google/cloud/hooks/test_natural_language.py
index 745b304506364..7618d752a842c 100644
--- a/tests/providers/google/cloud/hooks/test_natural_language.py
+++ b/tests/providers/google/cloud/hooks/test_natural_language.py
@@ -53,8 +53,8 @@ def test_language_service_client_creation(self, mock_client, mock_get_creds, moc
mock_client.assert_called_once_with(
credentials=mock_get_creds.return_value, client_info=mock_client_info.return_value
)
- self.assertEqual(mock_client.return_value, result)
- self.assertEqual(self.hook._conn, result)
+ assert mock_client.return_value == result
+ assert self.hook._conn == result
@mock.patch(
"airflow.providers.google.cloud.hooks.natural_language.CloudNaturalLanguageHook.get_conn",
@@ -63,7 +63,7 @@ def test_analyze_entities(self, get_conn):
get_conn.return_value.analyze_entities.return_value = API_RESPONSE
result = self.hook.analyze_entities(document=DOCUMENT, encoding_type=ENCODING_TYPE)
- self.assertEqual(result, API_RESPONSE)
+ assert result == API_RESPONSE
get_conn.return_value.analyze_entities.assert_called_once_with(
document=DOCUMENT, encoding_type=ENCODING_TYPE, retry=None, timeout=None, metadata=None
@@ -76,7 +76,7 @@ def test_analyze_entity_sentiment(self, get_conn):
get_conn.return_value.analyze_entity_sentiment.return_value = API_RESPONSE
result = self.hook.analyze_entity_sentiment(document=DOCUMENT, encoding_type=ENCODING_TYPE)
- self.assertEqual(result, API_RESPONSE)
+ assert result == API_RESPONSE
get_conn.return_value.analyze_entity_sentiment.assert_called_once_with(
document=DOCUMENT, encoding_type=ENCODING_TYPE, retry=None, timeout=None, metadata=None
@@ -89,7 +89,7 @@ def test_analyze_sentiment(self, get_conn):
get_conn.return_value.analyze_sentiment.return_value = API_RESPONSE
result = self.hook.analyze_sentiment(document=DOCUMENT, encoding_type=ENCODING_TYPE)
- self.assertEqual(result, API_RESPONSE)
+ assert result == API_RESPONSE
get_conn.return_value.analyze_sentiment.assert_called_once_with(
document=DOCUMENT, encoding_type=ENCODING_TYPE, retry=None, timeout=None, metadata=None
@@ -102,7 +102,7 @@ def test_analyze_syntax(self, get_conn):
get_conn.return_value.analyze_syntax.return_value = API_RESPONSE
result = self.hook.analyze_syntax(document=DOCUMENT, encoding_type=ENCODING_TYPE)
- self.assertEqual(result, API_RESPONSE)
+ assert result == API_RESPONSE
get_conn.return_value.analyze_syntax.assert_called_once_with(
document=DOCUMENT, encoding_type=ENCODING_TYPE, retry=None, timeout=None, metadata=None
@@ -115,7 +115,7 @@ def test_annotate_text(self, get_conn):
get_conn.return_value.annotate_text.return_value = API_RESPONSE
result = self.hook.annotate_text(document=DOCUMENT, encoding_type=ENCODING_TYPE, features=None)
- self.assertEqual(result, API_RESPONSE)
+ assert result == API_RESPONSE
get_conn.return_value.annotate_text.assert_called_once_with(
document=DOCUMENT,
@@ -133,7 +133,7 @@ def test_classify_text(self, get_conn):
get_conn.return_value.classify_text.return_value = API_RESPONSE
result = self.hook.classify_text(document=DOCUMENT)
- self.assertEqual(result, API_RESPONSE)
+ assert result == API_RESPONSE
get_conn.return_value.classify_text.assert_called_once_with(
document=DOCUMENT, retry=None, timeout=None, metadata=None
diff --git a/tests/providers/google/cloud/hooks/test_os_login.py b/tests/providers/google/cloud/hooks/test_os_login.py
index 7e37569539b63..d2b88e4c6c895 100644
--- a/tests/providers/google/cloud/hooks/test_os_login.py
+++ b/tests/providers/google/cloud/hooks/test_os_login.py
@@ -18,6 +18,7 @@
from typing import Dict, Sequence, Tuple
from unittest import TestCase, mock
+import pytest
from google.api_core.retry import Retry
from airflow import AirflowException
@@ -167,7 +168,7 @@ def setUp(
)
@mock.patch("airflow.providers.google.cloud.hooks.os_login.OSLoginHook.get_conn")
def test_import_ssh_public_key(self, mock_get_conn, mock_get_creds_and_project_id) -> None:
- with self.assertRaisesRegex(AirflowException, TEST_MESSAGE):
+ with pytest.raises(AirflowException, match=TEST_MESSAGE):
self.hook.import_ssh_public_key(
user=TEST_USER,
ssh_public_key=TEST_BODY,
diff --git a/tests/providers/google/cloud/hooks/test_pubsub.py b/tests/providers/google/cloud/hooks/test_pubsub.py
index d8328bb00e1a5..eadb8064e560f 100644
--- a/tests/providers/google/cloud/hooks/test_pubsub.py
+++ b/tests/providers/google/cloud/hooks/test_pubsub.py
@@ -21,6 +21,7 @@
from unittest import mock
from uuid import UUID
+import pytest
from google.api_core.exceptions import AlreadyExists, GoogleAPICallError
from google.cloud.exceptions import NotFound
from google.cloud.pubsub_v1.types import ReceivedMessage
@@ -81,13 +82,13 @@ def _generate_messages(self, count) -> List[ReceivedMessage]:
@mock.patch("airflow.providers.google.cloud.hooks.pubsub.PubSubHook._get_credentials")
@mock.patch("airflow.providers.google.cloud.hooks.pubsub.PublisherClient")
def test_publisher_client_creation(self, mock_client, mock_get_creds, mock_client_info):
- self.assertIsNone(self.pubsub_hook._client)
+ assert self.pubsub_hook._client is None
result = self.pubsub_hook.get_conn()
mock_client.assert_called_once_with(
credentials=mock_get_creds.return_value, client_info=mock_client_info.return_value
)
- self.assertEqual(mock_client.return_value, result)
- self.assertEqual(self.pubsub_hook._client, result)
+ assert mock_client.return_value == result
+ assert self.pubsub_hook._client == result
@mock.patch(
"airflow.providers.google.cloud.hooks.pubsub.PubSubHook.client_info", new_callable=mock.PropertyMock
@@ -95,12 +96,12 @@ def test_publisher_client_creation(self, mock_client, mock_get_creds, mock_clien
@mock.patch("airflow.providers.google.cloud.hooks.pubsub.PubSubHook._get_credentials")
@mock.patch("airflow.providers.google.cloud.hooks.pubsub.SubscriberClient")
def test_subscriber_client_creation(self, mock_client, mock_get_creds, mock_client_info):
- self.assertIsNone(self.pubsub_hook._client)
+ assert self.pubsub_hook._client is None
result = self.pubsub_hook.subscriber_client
mock_client.assert_called_once_with(
credentials=mock_get_creds.return_value, client_info=mock_client_info.return_value
)
- self.assertEqual(mock_client.return_value, result)
+ assert mock_client.return_value == result
@mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn'))
def test_create_nonexistent_topic(self, mock_service):
@@ -126,17 +127,17 @@ def test_delete_nonexisting_topic_failifnotexists(self, mock_service):
mock_service.return_value.delete_topic.side_effect = NotFound(
'Topic does not exists: %s' % EXPANDED_TOPIC
)
- with self.assertRaises(PubSubException) as e:
+ with pytest.raises(PubSubException) as ctx:
self.pubsub_hook.delete_topic(project_id=TEST_PROJECT, topic=TEST_TOPIC, fail_if_not_exists=True)
- self.assertEqual(str(e.exception), 'Topic does not exist: %s' % EXPANDED_TOPIC)
+ assert str(ctx.value) == 'Topic does not exist: %s' % EXPANDED_TOPIC
@mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn'))
def test_delete_topic_api_call_error(self, mock_service):
mock_service.return_value.delete_topic.side_effect = GoogleAPICallError(
'Error deleting topic: %s' % EXPANDED_TOPIC
)
- with self.assertRaises(PubSubException):
+ with pytest.raises(PubSubException):
self.pubsub_hook.delete_topic(project_id=TEST_PROJECT, topic=TEST_TOPIC, fail_if_not_exists=True)
@mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn'))
@@ -144,9 +145,9 @@ def test_create_preexisting_topic_failifexists(self, mock_service):
mock_service.return_value.create_topic.side_effect = AlreadyExists(
'Topic already exists: %s' % TEST_TOPIC
)
- with self.assertRaises(PubSubException) as e:
+ with pytest.raises(PubSubException) as ctx:
self.pubsub_hook.create_topic(project_id=TEST_PROJECT, topic=TEST_TOPIC, fail_if_exists=True)
- self.assertEqual(str(e.exception), 'Topic already exists: %s' % TEST_TOPIC)
+ assert str(ctx.value) == 'Topic already exists: %s' % TEST_TOPIC
@mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn'))
def test_create_preexisting_topic_nofailifexists(self, mock_service):
@@ -160,7 +161,7 @@ def test_create_topic_api_call_error(self, mock_service):
mock_service.return_value.create_topic.side_effect = GoogleAPICallError(
'Error creating topic: %s' % TEST_TOPIC
)
- with self.assertRaises(PubSubException):
+ with pytest.raises(PubSubException):
self.pubsub_hook.create_topic(project_id=TEST_PROJECT, topic=TEST_TOPIC, fail_if_exists=True)
@mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client'))
@@ -189,7 +190,7 @@ def test_create_nonexistent_subscription(self, mock_service):
timeout=None,
metadata=(),
)
- self.assertEqual(TEST_SUBSCRIPTION, response)
+ assert TEST_SUBSCRIPTION == response
@mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client'))
def test_create_subscription_different_project_topic(self, mock_service):
@@ -223,7 +224,7 @@ def test_create_subscription_different_project_topic(self, mock_service):
metadata=(),
)
- self.assertEqual(TEST_SUBSCRIPTION, response)
+ assert TEST_SUBSCRIPTION == response
@mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client'))
def test_delete_subscription(self, mock_service):
@@ -238,18 +239,18 @@ def test_delete_nonexisting_subscription_failifnotexists(self, mock_service):
mock_service.delete_subscription.side_effect = NotFound(
'Subscription does not exists: %s' % EXPANDED_SUBSCRIPTION
)
- with self.assertRaises(PubSubException) as e:
+ with pytest.raises(PubSubException) as ctx:
self.pubsub_hook.delete_subscription(
project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, fail_if_not_exists=True
)
- self.assertEqual(str(e.exception), 'Subscription does not exist: %s' % EXPANDED_SUBSCRIPTION)
+ assert str(ctx.value) == 'Subscription does not exist: %s' % EXPANDED_SUBSCRIPTION
@mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client'))
def test_delete_subscription_api_call_error(self, mock_service):
mock_service.delete_subscription.side_effect = GoogleAPICallError(
'Error deleting subscription %s' % EXPANDED_SUBSCRIPTION
)
- with self.assertRaises(PubSubException):
+ with pytest.raises(PubSubException):
self.pubsub_hook.delete_subscription(
project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, fail_if_not_exists=True
)
@@ -282,7 +283,7 @@ def test_create_subscription_without_subscription_name(
timeout=None,
metadata=(),
)
- self.assertEqual('sub-%s' % TEST_UUID, response)
+ assert 'sub-%s' % TEST_UUID == response
@mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client'))
def test_create_subscription_with_ack_deadline(self, mock_service):
@@ -310,7 +311,7 @@ def test_create_subscription_with_ack_deadline(self, mock_service):
timeout=None,
metadata=(),
)
- self.assertEqual(TEST_SUBSCRIPTION, response)
+ assert TEST_SUBSCRIPTION == response
@mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client'))
def test_create_subscription_with_filter(self, mock_service):
@@ -341,25 +342,25 @@ def test_create_subscription_with_filter(self, mock_service):
timeout=None,
metadata=(),
)
- self.assertEqual(TEST_SUBSCRIPTION, response)
+ assert TEST_SUBSCRIPTION == response
@mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client'))
def test_create_subscription_failifexists(self, mock_service):
mock_service.create_subscription.side_effect = AlreadyExists(
'Subscription already exists: %s' % EXPANDED_SUBSCRIPTION
)
- with self.assertRaises(PubSubException) as e:
+ with pytest.raises(PubSubException) as ctx:
self.pubsub_hook.create_subscription(
project_id=TEST_PROJECT, topic=TEST_TOPIC, subscription=TEST_SUBSCRIPTION, fail_if_exists=True
)
- self.assertEqual(str(e.exception), 'Subscription already exists: %s' % EXPANDED_SUBSCRIPTION)
+ assert str(ctx.value) == 'Subscription already exists: %s' % EXPANDED_SUBSCRIPTION
@mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client'))
def test_create_subscription_api_call_error(self, mock_service):
mock_service.create_subscription.side_effect = GoogleAPICallError(
'Error creating subscription %s' % EXPANDED_SUBSCRIPTION
)
- with self.assertRaises(PubSubException):
+ with pytest.raises(PubSubException):
self.pubsub_hook.create_subscription(
project_id=TEST_PROJECT, topic=TEST_TOPIC, subscription=TEST_SUBSCRIPTION, fail_if_exists=True
)
@@ -372,7 +373,7 @@ def test_create_subscription_nofailifexists(self, mock_service):
response = self.pubsub_hook.create_subscription(
project_id=TEST_PROJECT, topic=TEST_TOPIC, subscription=TEST_SUBSCRIPTION
)
- self.assertEqual(TEST_SUBSCRIPTION, response)
+ assert TEST_SUBSCRIPTION == response
@mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn'))
def test_publish(self, mock_service):
@@ -390,7 +391,7 @@ def test_publish_api_call_error(self, mock_service):
publish_method = mock_service.return_value.publish
publish_method.side_effect = GoogleAPICallError(f'Error publishing to topic {EXPANDED_SUBSCRIPTION}')
- with self.assertRaises(PubSubException):
+ with pytest.raises(PubSubException):
self.pubsub_hook.publish(project_id=TEST_PROJECT, topic=TEST_TOPIC, messages=TEST_MESSAGES)
@mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client'))
@@ -414,7 +415,7 @@ def test_pull(self, mock_service):
timeout=None,
metadata=(),
)
- self.assertEqual(pulled_messages, response)
+ assert pulled_messages == response
@mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client'))
def test_pull_no_messages(self, mock_service):
@@ -434,7 +435,7 @@ def test_pull_no_messages(self, mock_service):
timeout=None,
metadata=(),
)
- self.assertListEqual([], response)
+ assert [] == response
@parameterized.expand(
[
@@ -450,7 +451,7 @@ def test_pull_fails_on_exception(self, exception, mock_service):
pull_method = mock_service.pull
pull_method.side_effect = exception
- with self.assertRaises(PubSubException):
+ with pytest.raises(PubSubException):
self.pubsub_hook.pull(project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, max_messages=10)
pull_method.assert_called_once_with(
request=dict(
@@ -513,7 +514,7 @@ def test_acknowledge_fails_on_exception(self, exception, mock_service):
ack_method = mock_service.acknowledge
ack_method.side_effect = exception
- with self.assertRaises(PubSubException):
+ with pytest.raises(PubSubException):
self.pubsub_hook.acknowledge(
project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, ack_ids=['1', '2', '3']
)
@@ -559,6 +560,6 @@ def test_messages_validation_positive(self, messages):
]
)
def test_messages_validation_negative(self, messages, error_message):
- with self.assertRaises(PubSubException) as e:
+ with pytest.raises(PubSubException) as ctx:
PubSubHook._validate_messages(messages)
- self.assertEqual(str(e.exception), error_message)
+ assert str(ctx.value) == error_message
diff --git a/tests/providers/google/cloud/hooks/test_secret_manager.py b/tests/providers/google/cloud/hooks/test_secret_manager.py
index 37a492d740933..1f6e9d358533f 100644
--- a/tests/providers/google/cloud/hooks/test_secret_manager.py
+++ b/tests/providers/google/cloud/hooks/test_secret_manager.py
@@ -48,7 +48,7 @@ def test_get_missing_key(self, mock_get_credentials, mock_client):
secret = secrets_manager_hook.get_secret(secret_id="secret")
mock_client.secret_version_path.assert_called_once_with('example-project', 'secret', 'latest')
mock_client.access_secret_version.assert_called_once_with("full-path")
- self.assertIsNone(secret)
+ assert secret is None
@patch(INTERNAL_CLIENT_PACKAGE + "._SecretManagerClient.client", return_value=MagicMock())
@patch(
@@ -66,4 +66,4 @@ def test_get_existing_key(self, mock_get_credentials, mock_client):
secret = secrets_manager_hook.get_secret(secret_id="secret")
mock_client.secret_version_path.assert_called_once_with('example-project', 'secret', 'latest')
mock_client.access_secret_version.assert_called_once_with("full-path")
- self.assertEqual("result", secret)
+ assert "result" == secret
diff --git a/tests/providers/google/cloud/hooks/test_secret_manager_system.py b/tests/providers/google/cloud/hooks/test_secret_manager_system.py
index d16e7f6e9619b..f0786fa012b69 100644
--- a/tests/providers/google/cloud/hooks/test_secret_manager_system.py
+++ b/tests/providers/google/cloud/hooks/test_secret_manager_system.py
@@ -54,22 +54,22 @@ class TestSystemSecretsManager(TestCase):
def test_read_secret_from_secret_manager(self):
hook = SecretsManagerHook()
secret = hook.get_secret(secret_id=TEST_SECRET_ID)
- self.assertEqual(TEST_SECRET_VALUE, secret)
+ assert TEST_SECRET_VALUE == secret
@pytest.mark.usefixtures("helper_one_version")
@provide_gcp_context(GCP_SECRET_MANAGER_KEY)
def test_read_missing_secret_from_secret_manager(self):
hook = SecretsManagerHook()
secret = hook.get_secret(secret_id=TEST_MISSING_SECRET_ID)
- self.assertIsNone(secret)
+ assert secret is None
@pytest.mark.usefixtures("helper_two_versions")
@provide_gcp_context(GCP_SECRET_MANAGER_KEY)
def test_read_secret_different_versions_from_secret_manager(self):
hook = SecretsManagerHook()
secret = hook.get_secret(secret_id=TEST_SECRET_ID)
- self.assertEqual(TEST_SECRET_VALUE_UPDATED, secret)
+ assert TEST_SECRET_VALUE_UPDATED == secret
secret = hook.get_secret(secret_id=TEST_SECRET_ID, secret_version='1')
- self.assertEqual(TEST_SECRET_VALUE, secret)
+ assert TEST_SECRET_VALUE == secret
secret = hook.get_secret(secret_id=TEST_SECRET_ID, secret_version='2')
- self.assertEqual(TEST_SECRET_VALUE_UPDATED, secret)
+ assert TEST_SECRET_VALUE_UPDATED == secret
diff --git a/tests/providers/google/cloud/hooks/test_spanner.py b/tests/providers/google/cloud/hooks/test_spanner.py
index 27b7a06e67456..6cc6c25547bd1 100644
--- a/tests/providers/google/cloud/hooks/test_spanner.py
+++ b/tests/providers/google/cloud/hooks/test_spanner.py
@@ -52,8 +52,8 @@ def test_spanner_client_creation(self, mock_client, mock_get_creds, mock_client_
credentials=mock_get_creds.return_value,
client_info=mock_client_info.return_value,
)
- self.assertEqual(mock_client.return_value, result)
- self.assertEqual(self.spanner_hook_default_project_id._client, result)
+ assert mock_client.return_value == result
+ assert self.spanner_hook_default_project_id._client == result
@mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
def test_get_existing_instance(self, get_client):
@@ -65,7 +65,7 @@ def test_get_existing_instance(self, get_client):
)
get_client.assert_called_once_with(project_id='example-project')
instance_method.assert_called_once_with(instance_id='instance')
- self.assertIsNotNone(res)
+ assert res is not None
@mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
def test_get_existing_instance_overridden_project_id(self, get_client):
@@ -77,7 +77,7 @@ def test_get_existing_instance_overridden_project_id(self, get_client):
)
get_client.assert_called_once_with(project_id='new-project')
instance_method.assert_called_once_with(instance_id='instance')
- self.assertIsNotNone(res)
+ assert res is not None
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -103,7 +103,7 @@ def test_create_instance(self, get_client, mock_project_id):
display_name='database-name',
node_count=1,
)
- self.assertIsNone(res)
+ assert res is None
@mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
def test_create_instance_overridden_project_id(self, get_client):
@@ -124,7 +124,7 @@ def test_create_instance_overridden_project_id(self, get_client):
display_name='database-name',
node_count=1,
)
- self.assertIsNone(res)
+ assert res is None
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -153,7 +153,7 @@ def test_update_instance(self, get_client, mock_project_id):
node_count=2,
)
update_method.assert_called_once_with()
- self.assertIsNone(res)
+ assert res is None
@mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
def test_update_instance_overridden_project_id(self, get_client):
@@ -177,7 +177,7 @@ def test_update_instance_overridden_project_id(self, get_client):
node_count=2,
)
update_method.assert_called_once_with()
- self.assertIsNone(res)
+ assert res is None
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -198,7 +198,7 @@ def test_delete_instance(self, get_client, mock_project_id):
get_client.assert_called_once_with(project_id='example-project')
instance_method.assert_called_once_with('instance')
delete_method.assert_called_once_with()
- self.assertIsNone(res)
+ assert res is None
@mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
def test_delete_instance_overridden_project_id(self, get_client):
@@ -213,7 +213,7 @@ def test_delete_instance_overridden_project_id(self, get_client):
get_client.assert_called_once_with(project_id='new-project')
instance_method.assert_called_once_with('instance')
delete_method.assert_called_once_with()
- self.assertIsNone(res)
+ assert res is None
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -237,7 +237,7 @@ def test_get_database(self, get_client, mock_project_id):
instance_method.assert_called_once_with(instance_id='instance')
database_method.assert_called_once_with(database_id='database-name')
database_exists_method.assert_called_once_with()
- self.assertIsNotNone(res)
+ assert res is not None
@mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
def test_get_database_overridden_project_id(self, get_client):
@@ -254,7 +254,7 @@ def test_get_database_overridden_project_id(self, get_client):
instance_method.assert_called_once_with(instance_id='instance')
database_method.assert_called_once_with(database_id='database-name')
database_exists_method.assert_called_once_with()
- self.assertIsNotNone(res)
+ assert res is not None
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -278,7 +278,7 @@ def test_create_database(self, get_client, mock_project_id):
instance_method.assert_called_once_with(instance_id='instance')
database_method.assert_called_once_with(database_id='database-name', ddl_statements=[])
database_create_method.assert_called_once_with()
- self.assertIsNone(res)
+ assert res is None
@mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
def test_create_database_overridden_project_id(self, get_client):
@@ -297,7 +297,7 @@ def test_create_database_overridden_project_id(self, get_client):
instance_method.assert_called_once_with(instance_id='instance')
database_method.assert_called_once_with(database_id='database-name', ddl_statements=[])
database_create_method.assert_called_once_with()
- self.assertIsNone(res)
+ assert res is None
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -321,7 +321,7 @@ def test_update_database(self, get_client, mock_project_id):
instance_method.assert_called_once_with(instance_id='instance')
database_method.assert_called_once_with(database_id='database-name')
database_update_ddl_method.assert_called_once_with(ddl_statements=[], operation_id=None)
- self.assertIsNone(res)
+ assert res is None
@mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
def test_update_database_overridden_project_id(self, get_client):
@@ -340,7 +340,7 @@ def test_update_database_overridden_project_id(self, get_client):
instance_method.assert_called_once_with(instance_id='instance')
database_method.assert_called_once_with(database_id='database-name')
database_update_ddl_method.assert_called_once_with(ddl_statements=[], operation_id=None)
- self.assertIsNone(res)
+ assert res is None
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -366,7 +366,7 @@ def test_delete_database(self, get_client, mock_project_id):
database_method.assert_called_once_with(database_id='database-name')
database_exists_method.assert_called_once_with()
database_drop_method.assert_called_once_with()
- self.assertTrue(res)
+ assert res
@mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
def test_delete_database_overridden_project_id(self, get_client):
@@ -385,7 +385,7 @@ def test_delete_database_overridden_project_id(self, get_client):
database_method.assert_called_once_with(database_id='database-name')
database_exists_method.assert_called_once_with()
database_drop_method.assert_called_once_with()
- self.assertTrue(res)
+ assert res
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -409,7 +409,7 @@ def test_execute_dml(self, get_client, mock_project_id):
instance_method.assert_called_once_with(instance_id='instance')
database_method.assert_called_once_with(database_id='database-name')
run_in_transaction_method.assert_called_once_with(mock.ANY)
- self.assertIsNone(res)
+ assert res is None
@mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
def test_execute_dml_overridden_project_id(self, get_client):
@@ -425,7 +425,7 @@ def test_execute_dml_overridden_project_id(self, get_client):
instance_method.assert_called_once_with(instance_id='instance')
database_method.assert_called_once_with(database_id='database-name')
run_in_transaction_method.assert_called_once_with(mock.ANY)
- self.assertIsNone(res)
+ assert res is None
class TestGcpSpannerHookNoDefaultProjectID(unittest.TestCase):
@@ -451,8 +451,8 @@ def test_spanner_client_creation(self, mock_client, mock_get_creds, mock_client_
credentials=mock_get_creds.return_value,
client_info=mock_client_info.return_value,
)
- self.assertEqual(mock_client.return_value, result)
- self.assertEqual(self.spanner_hook_no_default_project_id._client, result)
+ assert mock_client.return_value == result
+ assert self.spanner_hook_no_default_project_id._client == result
@mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
def test_get_existing_instance_overridden_project_id(self, get_client):
@@ -464,7 +464,7 @@ def test_get_existing_instance_overridden_project_id(self, get_client):
)
get_client.assert_called_once_with(project_id='example-project')
instance_method.assert_called_once_with(instance_id='instance')
- self.assertIsNotNone(res)
+ assert res is not None
@mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
def test_get_non_existing_instance(self, get_client):
@@ -476,7 +476,7 @@ def test_get_non_existing_instance(self, get_client):
)
get_client.assert_called_once_with(project_id='example-project')
instance_method.assert_called_once_with(instance_id='instance')
- self.assertIsNone(res)
+ assert res is None
@mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
def test_create_instance_overridden_project_id(self, get_client):
@@ -497,7 +497,7 @@ def test_create_instance_overridden_project_id(self, get_client):
display_name='database-name',
node_count=1,
)
- self.assertIsNone(res)
+ assert res is None
@mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
def test_update_instance_overridden_project_id(self, get_client):
@@ -521,7 +521,7 @@ def test_update_instance_overridden_project_id(self, get_client):
node_count=2,
)
update_method.assert_called_once_with()
- self.assertIsNone(res)
+ assert res is None
@mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
def test_delete_instance_overridden_project_id(self, get_client):
@@ -536,7 +536,7 @@ def test_delete_instance_overridden_project_id(self, get_client):
get_client.assert_called_once_with(project_id='example-project')
instance_method.assert_called_once_with('instance')
delete_method.assert_called_once_with()
- self.assertIsNone(res)
+ assert res is None
@mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
def test_get_database_overridden_project_id(self, get_client):
@@ -555,7 +555,7 @@ def test_get_database_overridden_project_id(self, get_client):
instance_method.assert_called_once_with(instance_id='instance')
database_method.assert_called_once_with(database_id='database-name')
database_exists_method.assert_called_once_with()
- self.assertIsNotNone(res)
+ assert res is not None
@mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
def test_create_database_overridden_project_id(self, get_client):
@@ -574,7 +574,7 @@ def test_create_database_overridden_project_id(self, get_client):
instance_method.assert_called_once_with(instance_id='instance')
database_method.assert_called_once_with(database_id='database-name', ddl_statements=[])
database_create_method.assert_called_once_with()
- self.assertIsNone(res)
+ assert res is None
@mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
def test_update_database_overridden_project_id(self, get_client):
@@ -593,7 +593,7 @@ def test_update_database_overridden_project_id(self, get_client):
instance_method.assert_called_once_with(instance_id='instance')
database_method.assert_called_once_with(database_id='database-name')
database_update_ddl_method.assert_called_once_with(ddl_statements=[], operation_id=None)
- self.assertIsNone(res)
+ assert res is None
@mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
def test_update_database_overridden_project_id_and_operation(self, get_client):
@@ -613,7 +613,7 @@ def test_update_database_overridden_project_id_and_operation(self, get_client):
instance_method.assert_called_once_with(instance_id='instance')
database_method.assert_called_once_with(database_id='database-name')
database_update_ddl_method.assert_called_once_with(ddl_statements=[], operation_id="operation")
- self.assertIsNone(res)
+ assert res is None
@mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
def test_delete_database_overridden_project_id(self, get_client):
@@ -634,7 +634,7 @@ def test_delete_database_overridden_project_id(self, get_client):
database_method.assert_called_once_with(database_id='database-name')
database_exists_method.assert_called_once_with()
database_drop_method.assert_called_once_with()
- self.assertTrue(res)
+ assert res
@mock.patch('airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client')
def test_delete_database_missing_database(self, get_client):
@@ -673,4 +673,4 @@ def test_execute_dml_overridden_project_id(self, get_client):
instance_method.assert_called_once_with(instance_id='instance')
database_method.assert_called_once_with(database_id='database-name')
run_in_transaction_method.assert_called_once_with(mock.ANY)
- self.assertIsNone(res)
+ assert res is None
diff --git a/tests/providers/google/cloud/hooks/test_speech_to_text.py b/tests/providers/google/cloud/hooks/test_speech_to_text.py
index be5d933137823..924d73d62b9f3 100644
--- a/tests/providers/google/cloud/hooks/test_speech_to_text.py
+++ b/tests/providers/google/cloud/hooks/test_speech_to_text.py
@@ -47,8 +47,8 @@ def test_speech_client_creation(self, mock_client, mock_get_creds, mock_client_i
mock_client.assert_called_once_with(
credentials=mock_get_creds.return_value, client_info=mock_client_info.return_value
)
- self.assertEqual(mock_client.return_value, result)
- self.assertEqual(self.gcp_speech_to_text_hook._client, result)
+ assert mock_client.return_value == result
+ assert self.gcp_speech_to_text_hook._client == result
@patch("airflow.providers.google.cloud.hooks.speech_to_text.CloudSpeechToTextHook.get_conn")
def test_synthesize_speech(self, get_conn):
diff --git a/tests/providers/google/cloud/hooks/test_tasks.py b/tests/providers/google/cloud/hooks/test_tasks.py
index 65045953b2b7c..5abd7616c2732 100644
--- a/tests/providers/google/cloud/hooks/test_tasks.py
+++ b/tests/providers/google/cloud/hooks/test_tasks.py
@@ -54,8 +54,8 @@ def test_cloud_tasks_client_creation(self, mock_client, mock_get_creds, mock_cli
mock_client.assert_called_once_with(
credentials=mock_get_creds.return_value, client_info=mock_client_info.return_value
)
- self.assertEqual(mock_client.return_value, result)
- self.assertEqual(self.hook._client, result)
+ assert mock_client.return_value == result
+ assert self.hook._client == result
@mock.patch(
"airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn",
@@ -69,7 +69,7 @@ def test_create_queue(self, get_conn):
project_id=PROJECT_ID,
)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.create_queue.assert_called_once_with(
request=dict(parent=FULL_LOCATION_PATH, queue=Queue(name=FULL_QUEUE_PATH)),
@@ -90,7 +90,7 @@ def test_update_queue(self, get_conn):
project_id=PROJECT_ID,
)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.update_queue.assert_called_once_with(
request=dict(queue=Queue(name=FULL_QUEUE_PATH, state=3), update_mask=None),
@@ -106,7 +106,7 @@ def test_update_queue(self, get_conn):
def test_get_queue(self, get_conn):
result = self.hook.get_queue(location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
get_conn.return_value.get_queue.assert_called_once_with(
request=dict(name=FULL_QUEUE_PATH), retry=None, timeout=None, metadata=()
@@ -119,7 +119,7 @@ def test_get_queue(self, get_conn):
def test_list_queues(self, get_conn):
result = self.hook.list_queues(location=LOCATION, project_id=PROJECT_ID)
- self.assertEqual(result, [Queue(name=FULL_QUEUE_PATH)])
+ assert result == [Queue(name=FULL_QUEUE_PATH)]
get_conn.return_value.list_queues.assert_called_once_with(
request=dict(parent=FULL_LOCATION_PATH, filter=None, page_size=None),
@@ -135,7 +135,7 @@ def test_list_queues(self, get_conn):
def test_delete_queue(self, get_conn):
result = self.hook.delete_queue(location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID)
- self.assertEqual(result, None)
+ assert result is None
get_conn.return_value.delete_queue.assert_called_once_with(
request=dict(name=FULL_QUEUE_PATH), retry=None, timeout=None, metadata=()
@@ -148,7 +148,7 @@ def test_delete_queue(self, get_conn):
def test_purge_queue(self, get_conn):
result = self.hook.purge_queue(location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID)
- self.assertEqual(result, Queue(name=FULL_QUEUE_PATH))
+ assert result == Queue(name=FULL_QUEUE_PATH)
get_conn.return_value.purge_queue.assert_called_once_with(
request=dict(name=FULL_QUEUE_PATH), retry=None, timeout=None, metadata=()
@@ -161,7 +161,7 @@ def test_purge_queue(self, get_conn):
def test_pause_queue(self, get_conn):
result = self.hook.pause_queue(location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID)
- self.assertEqual(result, Queue(name=FULL_QUEUE_PATH))
+ assert result == Queue(name=FULL_QUEUE_PATH)
get_conn.return_value.pause_queue.assert_called_once_with(
request=dict(name=FULL_QUEUE_PATH), retry=None, timeout=None, metadata=()
@@ -174,7 +174,7 @@ def test_pause_queue(self, get_conn):
def test_resume_queue(self, get_conn):
result = self.hook.resume_queue(location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID)
- self.assertEqual(result, Queue(name=FULL_QUEUE_PATH))
+ assert result == Queue(name=FULL_QUEUE_PATH)
get_conn.return_value.resume_queue.assert_called_once_with(
request=dict(name=FULL_QUEUE_PATH), retry=None, timeout=None, metadata=()
@@ -193,7 +193,7 @@ def test_create_task(self, get_conn):
task_name=TASK_NAME,
)
- self.assertEqual(result, Task(name=FULL_TASK_PATH))
+ assert result == Task(name=FULL_TASK_PATH)
get_conn.return_value.create_task.assert_called_once_with(
request=dict(parent=FULL_QUEUE_PATH, task=Task(name=FULL_TASK_PATH), response_view=None),
@@ -214,7 +214,7 @@ def test_get_task(self, get_conn):
project_id=PROJECT_ID,
)
- self.assertEqual(result, Task(name=FULL_TASK_PATH))
+ assert result == Task(name=FULL_TASK_PATH)
get_conn.return_value.get_task.assert_called_once_with(
request=dict(name=FULL_TASK_PATH, response_view=None),
@@ -230,7 +230,7 @@ def test_get_task(self, get_conn):
def test_list_tasks(self, get_conn):
result = self.hook.list_tasks(location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID)
- self.assertEqual(result, [Task(name=FULL_TASK_PATH)])
+ assert result == [Task(name=FULL_TASK_PATH)]
get_conn.return_value.list_tasks.assert_called_once_with(
request=dict(parent=FULL_QUEUE_PATH, response_view=None, page_size=None),
@@ -251,7 +251,7 @@ def test_delete_task(self, get_conn):
project_id=PROJECT_ID,
)
- self.assertEqual(result, None)
+ assert result is None
get_conn.return_value.delete_task.assert_called_once_with(
request=dict(name=FULL_TASK_PATH), retry=None, timeout=None, metadata=()
@@ -269,7 +269,7 @@ def test_run_task(self, get_conn):
project_id=PROJECT_ID,
)
- self.assertEqual(result, Task(name=FULL_TASK_PATH))
+ assert result == Task(name=FULL_TASK_PATH)
get_conn.return_value.run_task.assert_called_once_with(
request=dict(name=FULL_TASK_PATH, response_view=None),
diff --git a/tests/providers/google/cloud/hooks/test_text_to_speech.py b/tests/providers/google/cloud/hooks/test_text_to_speech.py
index cc627a380b0dd..87e7b94d6d0f2 100644
--- a/tests/providers/google/cloud/hooks/test_text_to_speech.py
+++ b/tests/providers/google/cloud/hooks/test_text_to_speech.py
@@ -47,8 +47,8 @@ def test_text_to_speech_client_creation(self, mock_client, mock_get_creds, mock_
mock_client.assert_called_once_with(
credentials=mock_get_creds.return_value, client_info=mock_client_info.return_value
)
- self.assertEqual(mock_client.return_value, result)
- self.assertEqual(self.gcp_text_to_speech_hook._client, result)
+ assert mock_client.return_value == result
+ assert self.gcp_text_to_speech_hook._client == result
@patch("airflow.providers.google.cloud.hooks.text_to_speech.CloudTextToSpeechHook.get_conn")
def test_synthesize_speech(self, get_conn):
diff --git a/tests/providers/google/cloud/hooks/test_translate.py b/tests/providers/google/cloud/hooks/test_translate.py
index 43c559c3b1a35..99b27f7d8e063 100644
--- a/tests/providers/google/cloud/hooks/test_translate.py
+++ b/tests/providers/google/cloud/hooks/test_translate.py
@@ -44,8 +44,8 @@ def test_translate_client_creation(self, mock_client, mock_get_creds, mock_clien
mock_client.assert_called_once_with(
credentials=mock_get_creds.return_value, client_info=mock_client_info.return_value
)
- self.assertEqual(mock_client.return_value, result)
- self.assertEqual(self.hook._client, result)
+ assert mock_client.return_value == result
+ assert self.hook._client == result
@mock.patch('airflow.providers.google.cloud.hooks.translate.CloudTranslateHook.get_conn')
def test_translate_called(self, get_conn):
@@ -66,15 +66,12 @@ def test_translate_called(self, get_conn):
model='base',
)
# Then
- self.assertEqual(
- result,
- {
- 'translatedText': 'Yellowing self Gęśle',
- 'detectedSourceLanguage': 'pl',
- 'model': 'base',
- 'input': 'zażółć gęślą jaźń',
- },
- )
+ assert result == {
+ 'translatedText': 'Yellowing self Gęśle',
+ 'detectedSourceLanguage': 'pl',
+ 'model': 'base',
+ 'input': 'zażółć gęślą jaźń',
+ }
translate_method.assert_called_once_with(
values=['zażółć gęślą jaźń'],
target_language='en',
diff --git a/tests/providers/google/cloud/hooks/test_video_intelligence.py b/tests/providers/google/cloud/hooks/test_video_intelligence.py
index 624d0c597b95c..4715ac1064db9 100644
--- a/tests/providers/google/cloud/hooks/test_video_intelligence.py
+++ b/tests/providers/google/cloud/hooks/test_video_intelligence.py
@@ -53,8 +53,8 @@ def test_video_intelligence_service_client_creation(self, mock_client, mock_get_
mock_client.assert_called_once_with(
credentials=mock_get_creds.return_value, client_info=mock_client_info.return_value
)
- self.assertEqual(mock_client.return_value, result)
- self.assertEqual(self.hook._conn, result)
+ assert mock_client.return_value == result
+ assert self.hook._conn == result
@mock.patch("airflow.providers.google.cloud.hooks.video_intelligence.CloudVideoIntelligenceHook.get_conn")
def test_annotate_video(self, get_conn):
@@ -66,7 +66,7 @@ def test_annotate_video(self, get_conn):
result = self.hook.annotate_video(input_uri=INPUT_URI, features=FEATURES)
# Then
- self.assertIs(result, ANNOTATE_VIDEO_RESPONSE)
+ assert result is ANNOTATE_VIDEO_RESPONSE
annotate_video_method.assert_called_once_with(
input_uri=INPUT_URI,
input_content=None,
@@ -89,7 +89,7 @@ def test_annotate_video_with_output_uri(self, get_conn):
result = self.hook.annotate_video(input_uri=INPUT_URI, output_uri=OUTPUT_URI, features=FEATURES)
# Then
- self.assertIs(result, ANNOTATE_VIDEO_RESPONSE)
+ assert result is ANNOTATE_VIDEO_RESPONSE
annotate_video_method.assert_called_once_with(
input_uri=INPUT_URI,
output_uri=OUTPUT_URI,
diff --git a/tests/providers/google/cloud/hooks/test_vision.py b/tests/providers/google/cloud/hooks/test_vision.py
index 31f004c98419f..09edb53845240 100644
--- a/tests/providers/google/cloud/hooks/test_vision.py
+++ b/tests/providers/google/cloud/hooks/test_vision.py
@@ -18,6 +18,7 @@
import unittest
from unittest import mock
+import pytest
from google.cloud.vision import enums
from google.cloud.vision_v1 import ProductSearchClient
from google.cloud.vision_v1.proto.image_annotator_pb2 import (
@@ -92,8 +93,8 @@ def test_product_search_client_creation(self, mock_client, mock_get_creds, mock_
mock_client.assert_called_once_with(
credentials=mock_get_creds.return_value, client_info=mock_client_info.return_value
)
- self.assertEqual(mock_client.return_value, result)
- self.assertEqual(self.hook._client, result)
+ assert mock_client.return_value == result
+ assert self.hook._client == result
@mock.patch('airflow.providers.google.cloud.hooks.vision.CloudVisionHook.get_conn')
def test_create_productset_explicit_id(self, get_conn):
@@ -115,7 +116,7 @@ def test_create_productset_explicit_id(self, get_conn):
# Then
# ProductSet ID was provided explicitly in the method call above, should be returned from the method
- self.assertEqual(result, PRODUCTSET_ID_TEST)
+ assert result == PRODUCTSET_ID_TEST
create_product_set_method.assert_called_once_with(
parent=parent,
product_set=product_set,
@@ -143,7 +144,7 @@ def test_create_productset_autogenerated_id(self, get_conn):
# Then
# ProductSet ID was not provided in the method call above. Should be extracted from the API response
# and returned.
- self.assertEqual(result, autogenerated_id)
+ assert result == autogenerated_id
create_product_set_method.assert_called_once_with(
parent=parent,
product_set=product_set,
@@ -162,7 +163,7 @@ def test_create_productset_autogenerated_id_wrong_api_response(self, get_conn):
parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST)
product_set = ProductSet()
# When
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
self.hook.create_product_set(
location=LOC_ID_TEST,
product_set_id=None,
@@ -174,8 +175,8 @@ def test_create_productset_autogenerated_id_wrong_api_response(self, get_conn):
)
# Then
# API response was wrong (None) and thus ProductSet ID extraction should fail.
- err = cm.exception
- self.assertIn('Unable to get name from response...', str(err))
+ err = ctx.value
+ assert 'Unable to get name from response...' in str(err)
create_product_set_method.assert_called_once_with(
parent=parent,
product_set=product_set,
@@ -197,8 +198,8 @@ def test_get_productset(self, get_conn):
location=LOC_ID_TEST, product_set_id=PRODUCTSET_ID_TEST, project_id=PROJECT_ID_TEST
)
# Then
- self.assertTrue(response)
- self.assertEqual(response, MessageToDict(response_product_set))
+ assert response
+ assert response == MessageToDict(response_product_set)
get_product_set_method.assert_called_once_with(name=name, retry=None, timeout=None, metadata=None)
@mock.patch('airflow.providers.google.cloud.hooks.vision.CloudVisionHook.get_conn')
@@ -222,7 +223,7 @@ def test_update_productset_no_explicit_name(self, get_conn):
metadata=None,
)
# Then
- self.assertEqual(result, MessageToDict(product_set))
+ assert result == MessageToDict(product_set)
update_product_set_method.assert_called_once_with(
product_set=ProductSet(name=productset_name),
metadata=None,
@@ -241,7 +242,7 @@ def test_update_productset_no_explicit_name_and_missing_params_for_constructed_n
update_product_set_method.return_value = None
product_set = ProductSet()
# When
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
self.hook.update_product_set(
location=location,
product_set_id=product_set_id,
@@ -252,9 +253,9 @@ def test_update_productset_no_explicit_name_and_missing_params_for_constructed_n
timeout=None,
metadata=None,
)
- err = cm.exception
- self.assertTrue(err)
- self.assertIn(ERR_UNABLE_TO_CREATE.format(label='ProductSet', id_label='productset_id'), str(err))
+ err = ctx.value
+ assert err
+ assert ERR_UNABLE_TO_CREATE.format(label='ProductSet', id_label='productset_id') in str(err)
update_product_set_method.assert_not_called()
@parameterized.expand([(None, None), (None, PRODUCTSET_ID_TEST), (LOC_ID_TEST, None)])
@@ -281,7 +282,7 @@ def test_update_productset_explicit_name_missing_params_for_constructed_name(
metadata=None,
)
# Then
- self.assertEqual(result, MessageToDict(product_set))
+ assert result == MessageToDict(product_set)
update_product_set_method.assert_called_once_with(
product_set=ProductSet(name=explicit_ps_name),
metadata=None,
@@ -306,7 +307,7 @@ def test_update_productset_explicit_name_different_from_constructed(self, get_co
# Location and product_set_id are passed in addition to a ProductSet with an explicit name,
# but both names differ (constructed != explicit).
# Should throw AirflowException in this case.
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
self.hook.update_product_set(
location=LOC_ID_TEST,
product_set_id=PRODUCTSET_ID_TEST,
@@ -317,17 +318,17 @@ def test_update_productset_explicit_name_different_from_constructed(self, get_co
timeout=None,
metadata=None,
)
- err = cm.exception
+ err = ctx.value
# self.assertIn("The required parameter 'project_id' is missing", str(err))
- self.assertTrue(err)
- self.assertIn(
+ assert err
+ assert (
ERR_DIFF_NAMES.format(
explicit_name=explicit_ps_name,
constructed_name=template_ps_name,
label="ProductSet",
id_label="productset_id",
- ),
- str(err),
+ )
+ in str(err)
)
update_product_set_method.assert_not_called()
@@ -342,7 +343,7 @@ def test_delete_productset(self, get_conn):
location=LOC_ID_TEST, product_set_id=PRODUCTSET_ID_TEST, project_id=PROJECT_ID_TEST
)
# Then
- self.assertIsNone(response)
+ assert response is None
delete_product_set_method.assert_called_once_with(name=name, retry=None, timeout=None, metadata=None)
@mock.patch(
@@ -363,7 +364,7 @@ def test_create_reference_image_explicit_id(self, get_conn):
)
# Then
# Product ID was provided explicitly in the method call above, should be returned from the method
- self.assertEqual(result, REFERENCE_IMAGE_ID_TEST)
+ assert result == REFERENCE_IMAGE_ID_TEST
create_reference_image_method.assert_called_once_with(
parent=PRODUCT_NAME,
reference_image=REFERENCE_IMAGE_WITHOUT_ID_NAME,
@@ -391,7 +392,7 @@ def test_create_reference_image_autogenerated_id(self, get_conn):
)
# Then
# Product ID was provided explicitly in the method call above, should be returned from the method
- self.assertEqual(result, REFERENCE_IMAGE_GEN_ID_TEST)
+ assert result == REFERENCE_IMAGE_GEN_ID_TEST
create_reference_image_method.assert_called_once_with(
parent=PRODUCT_NAME,
reference_image=REFERENCE_IMAGE_TEST,
@@ -477,7 +478,7 @@ def test_create_product_explicit_id(self, get_conn):
)
# Then
# Product ID was provided explicitly in the method call above, should be returned from the method
- self.assertEqual(result, PRODUCT_ID_TEST)
+ assert result == PRODUCT_ID_TEST
create_product_method.assert_called_once_with(
parent=parent,
product=product,
@@ -505,7 +506,7 @@ def test_create_product_autogenerated_id(self, get_conn):
# Then
# Product ID was not provided in the method call above. Should be extracted from the API response
# and returned.
- self.assertEqual(result, autogenerated_id)
+ assert result == autogenerated_id
create_product_method.assert_called_once_with(
parent=parent, product=product, product_id=None, retry=None, timeout=None, metadata=None
)
@@ -520,14 +521,14 @@ def test_create_product_autogenerated_id_wrong_name_in_response(self, get_conn):
parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST)
product = Product()
# When
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
self.hook.create_product(
location=LOC_ID_TEST, product_id=None, product=product, project_id=PROJECT_ID_TEST
)
# Then
# API response was wrong (wrong name format) and thus ProductSet ID extraction should fail.
- err = cm.exception
- self.assertIn('Unable to get id from name', str(err))
+ err = ctx.value
+ assert 'Unable to get id from name' in str(err)
create_product_method.assert_called_once_with(
parent=parent, product=product, product_id=None, retry=None, timeout=None, metadata=None
)
@@ -541,14 +542,14 @@ def test_create_product_autogenerated_id_wrong_api_response(self, get_conn):
parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST)
product = Product()
# When
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
self.hook.create_product(
location=LOC_ID_TEST, product_id=None, product=product, project_id=PROJECT_ID_TEST
)
# Then
# API response was wrong (None) and thus ProductSet ID extraction should fail.
- err = cm.exception
- self.assertIn('Unable to get name from response...', str(err))
+ err = ctx.value
+ assert 'Unable to get name from response...' in str(err)
create_product_method.assert_called_once_with(
parent=parent, product=product, product_id=None, retry=None, timeout=None, metadata=None
)
@@ -572,7 +573,7 @@ def test_update_product_no_explicit_name(self, get_conn):
metadata=None,
)
# Then
- self.assertEqual(result, MessageToDict(product))
+ assert result == MessageToDict(product)
update_product_method.assert_called_once_with(
product=Product(name=product_name), metadata=None, retry=None, timeout=None, update_mask=None
)
@@ -587,7 +588,7 @@ def test_update_product_no_explicit_name_and_missing_params_for_constructed_name
update_product_method.return_value = None
product = Product()
# When
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
self.hook.update_product(
location=location,
product_id=product_id,
@@ -598,12 +599,9 @@ def test_update_product_no_explicit_name_and_missing_params_for_constructed_name
timeout=None,
metadata=None,
)
- err = cm.exception
- self.assertTrue(err)
- self.assertIn(
- ERR_UNABLE_TO_CREATE.format(label='Product', id_label='product_id'),
- str(err),
- )
+ err = ctx.value
+ assert err
+ assert ERR_UNABLE_TO_CREATE.format(label='Product', id_label='product_id') in str(err)
update_product_method.assert_not_called()
@parameterized.expand([(None, None), (None, PRODUCT_ID_TEST), (LOC_ID_TEST, None)])
@@ -630,7 +628,7 @@ def test_update_product_explicit_name_missing_params_for_constructed_name(
metadata=None,
)
# Then
- self.assertEqual(result, MessageToDict(product))
+ assert result == MessageToDict(product)
update_product_method.assert_called_once_with(
product=Product(name=explicit_p_name), metadata=None, retry=None, timeout=None, update_mask=None
)
@@ -649,7 +647,7 @@ def test_update_product_explicit_name_different_from_constructed(self, get_conn)
# Location and product_id are passed in addition to a Product with an explicit name,
# but both names differ (constructed != explicit).
# Should throw AirflowException in this case.
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
self.hook.update_product(
location=LOC_ID_TEST,
product_id=PRODUCT_ID_TEST,
@@ -660,16 +658,16 @@ def test_update_product_explicit_name_different_from_constructed(self, get_conn)
timeout=None,
metadata=None,
)
- err = cm.exception
- self.assertTrue(err)
- self.assertIn(
+ err = ctx.value
+ assert err
+ assert (
ERR_DIFF_NAMES.format(
explicit_name=explicit_p_name,
constructed_name=template_p_name,
label="Product",
id_label="product_id",
- ),
- str(err),
+ )
+ in str(err)
)
update_product_method.assert_not_called()
@@ -684,7 +682,7 @@ def test_delete_product(self, get_conn):
location=LOC_ID_TEST, product_id=PRODUCT_ID_TEST, project_id=PROJECT_ID_TEST
)
# Then
- self.assertIsNone(response)
+ assert response is None
delete_product_method.assert_called_once_with(name=name, retry=None, timeout=None, metadata=None)
@mock.patch("airflow.providers.google.cloud.hooks.vision.CloudVisionHook.annotator_client")
@@ -730,11 +728,11 @@ def test_detect_text_with_error_response(self, annotator_client_mock):
)
# When
- with self.assertRaises(AirflowException) as msg:
+ with pytest.raises(AirflowException) as ctx:
self.hook.text_detection(image=DETECT_TEST_IMAGE)
- err = msg.exception
- self.assertIn("test error message", str(err))
+ err = ctx.value
+ assert "test error message" in str(err)
@mock.patch("airflow.providers.google.cloud.hooks.vision.CloudVisionHook.annotator_client")
def test_document_text_detection(self, annotator_client_mock):
@@ -779,11 +777,11 @@ def test_detect_document_text_with_error_response(self, annotator_client_mock):
)
# When
- with self.assertRaises(AirflowException) as msg:
+ with pytest.raises(AirflowException) as ctx:
self.hook.document_text_detection(image=DETECT_TEST_IMAGE)
- err = msg.exception
- self.assertIn("test error message", str(err))
+ err = ctx.value
+ assert "test error message" in str(err)
@mock.patch("airflow.providers.google.cloud.hooks.vision.CloudVisionHook.annotator_client")
def test_label_detection(self, annotator_client_mock):
@@ -828,11 +826,11 @@ def test_label_detection_with_error_response(self, annotator_client_mock):
)
# When
- with self.assertRaises(AirflowException) as msg:
+ with pytest.raises(AirflowException) as ctx:
self.hook.label_detection(image=DETECT_TEST_IMAGE)
- err = msg.exception
- self.assertIn("test error message", str(err))
+ err = ctx.value
+ assert "test error message" in str(err)
@mock.patch("airflow.providers.google.cloud.hooks.vision.CloudVisionHook.annotator_client")
def test_safe_search_detection(self, annotator_client_mock):
@@ -889,8 +887,8 @@ def test_safe_search_detection_with_error_response(self, annotator_client_mock):
)
# When
- with self.assertRaises(AirflowException) as msg:
+ with pytest.raises(AirflowException) as ctx:
self.hook.safe_search_detection(image=DETECT_TEST_IMAGE)
- err = msg.exception
- self.assertIn("test error message", str(err))
+ err = ctx.value
+ assert "test error message" in str(err)
diff --git a/tests/providers/google/cloud/log/test_gcs_task_handler.py b/tests/providers/google/cloud/log/test_gcs_task_handler.py
index 2f0a746b00eb0..dcf372066becc 100644
--- a/tests/providers/google/cloud/log/test_gcs_task_handler.py
+++ b/tests/providers/google/cloud/log/test_gcs_task_handler.py
@@ -65,7 +65,7 @@ def test_hook(self, mock_client, mock_creds):
mock_client.assert_called_once_with(
client_info=mock.ANY, credentials="TEST_CREDENTIALS", project="TEST_PROJECT_ID"
)
- self.assertEqual(mock_client.return_value, return_value)
+ assert mock_client.return_value == return_value
@conf_vars({("logging", "remote_log_conn_id"): "gcs_default"})
@mock.patch(
@@ -82,10 +82,8 @@ def test_should_read_logs_from_remote(self, mock_blob, mock_client, mock_creds):
"gs://bucket/remote/log/location/1.log", mock_client.return_value
)
- self.assertEqual(
- "*** Reading remote log from gs://bucket/remote/log/location/1.log.\nCONTENT\n", logs
- )
- self.assertEqual({"end_of_log": True}, metadata)
+ assert "*** Reading remote log from gs://bucket/remote/log/location/1.log.\nCONTENT\n" == logs
+ assert {"end_of_log": True} == metadata
@mock.patch(
"airflow.providers.google.cloud.log.gcs_task_handler.get_credentials_and_project_id",
@@ -99,12 +97,11 @@ def test_should_read_from_local(self, mock_blob, mock_client, mock_creds):
self.gcs_task_handler.set_context(self.ti)
log, metadata = self.gcs_task_handler._read(self.ti, self.ti.try_number)
- self.assertEqual(
- log,
- "*** Unable to read remote log from gs://bucket/remote/log/location/1.log\n*** "
- f"Failed to connect\n\n*** Reading local file: {self.local_log_location}/1.log\n",
+ assert (
+ log == "*** Unable to read remote log from gs://bucket/remote/log/location/1.log\n*** "
+ f"Failed to connect\n\n*** Reading local file: {self.local_log_location}/1.log\n"
)
- self.assertDictEqual(metadata, {"end_of_log": True})
+ assert metadata == {"end_of_log": True}
mock_blob.from_string.assert_called_once_with(
"gs://bucket/remote/log/location/1.log", mock_client.return_value
)
@@ -142,7 +139,7 @@ def test_write_to_remote_on_close(self, mock_blob, mock_client, mock_creds):
any_order=False,
)
mock_blob.from_string.return_value.upload_from_string(data="CONTENT\nMESSAGE\n")
- self.assertEqual(self.gcs_task_handler.closed, True)
+ assert self.gcs_task_handler.closed is True
@mock.patch(
"airflow.providers.google.cloud.log.gcs_task_handler.get_credentials_and_project_id",
@@ -169,13 +166,10 @@ def test_failed_write_to_remote_on_close(self, mock_blob, mock_client, mock_cred
with self.assertLogs(self.gcs_task_handler.log) as cm:
self.gcs_task_handler.close()
- self.assertEqual(
- cm.output,
- [
- 'ERROR:airflow.providers.google.cloud.log.gcs_task_handler.GCSTaskHandler:Could '
- 'not write logs to gs://bucket/remote/log/location/1.log: Failed to connect',
- ],
- )
+ assert cm.output == [
+ 'ERROR:airflow.providers.google.cloud.log.gcs_task_handler.GCSTaskHandler:Could '
+ 'not write logs to gs://bucket/remote/log/location/1.log: Failed to connect',
+ ]
mock_blob.assert_has_calls(
[
mock.call.from_string("gs://bucket/remote/log/location/1.log", mock_client.return_value),
diff --git a/tests/providers/google/cloud/log/test_gcs_task_handler_system.py b/tests/providers/google/cloud/log/test_gcs_task_handler_system.py
index aecced21d9031..4bc2a6224dea5 100644
--- a/tests/providers/google/cloud/log/test_gcs_task_handler_system.py
+++ b/tests/providers/google/cloud/log/test_gcs_task_handler_system.py
@@ -72,8 +72,8 @@ def test_should_read_logs(self, session):
AIRFLOW__CORE__DAGS_FOLDER=example_complex.__file__,
GOOGLE_APPLICATION_CREDENTIALS=resolve_full_gcp_key_path(GCP_GCS_KEY),
):
- self.assertEqual(0, subprocess.Popen(["airflow", "dags", "trigger", "example_complex"]).wait())
- self.assertEqual(0, subprocess.Popen(["airflow", "scheduler", "--num-runs", "1"]).wait())
+ assert 0 == subprocess.Popen(["airflow", "dags", "trigger", "example_complex"]).wait()
+ assert 0 == subprocess.Popen(["airflow", "scheduler", "--num-runs", "1"]).wait()
ti = session.query(TaskInstance).filter(TaskInstance.task_id == "create_entry_group").first()
dag = DagBag(dag_folder=example_complex.__file__).dags['example_complex']
@@ -96,4 +96,4 @@ def assert_remote_logs(self, expected_message, ti):
task_log_reader = TaskLogReader()
logs = "\n".join(task_log_reader.read_log_stream(ti, try_number=None, metadata={}))
- self.assertIn(expected_message, logs)
+ assert expected_message in logs
diff --git a/tests/providers/google/cloud/log/test_stackdriver_task_handler.py b/tests/providers/google/cloud/log/test_stackdriver_task_handler.py
index 21c761dbe3cbd..4159e9e0f4e54 100644
--- a/tests/providers/google/cloud/log/test_stackdriver_task_handler.py
+++ b/tests/providers/google/cloud/log/test_stackdriver_task_handler.py
@@ -135,8 +135,8 @@ def test_should_read_logs_for_all_try(self, mock_client, mock_get_creds_and_proj
'labels.execution_date="2016-01-01T00:00:00+00:00"',
page_token=None,
)
- self.assertEqual(['MSG1\nMSG2'], logs)
- self.assertEqual([{'end_of_log': True}], metadata)
+ assert ['MSG1\nMSG2'] == logs
+ assert [{'end_of_log': True}] == metadata
@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
@mock.patch(
@@ -156,8 +156,8 @@ def test_should_read_logs_for_task_with_quote(self, mock_client, mock_get_creds_
'labels.execution_date="2016-01-01T00:00:00+00:00"',
page_token=None,
)
- self.assertEqual(['MSG1\nMSG2'], logs)
- self.assertEqual([{'end_of_log': True}], metadata)
+ assert ['MSG1\nMSG2'] == logs
+ assert [{'end_of_log': True}] == metadata
@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
@mock.patch(
@@ -178,8 +178,8 @@ def test_should_read_logs_for_single_try(self, mock_client, mock_get_creds_and_p
'labels.try_number="3"',
page_token=None,
)
- self.assertEqual(['MSG1\nMSG2'], logs)
- self.assertEqual([{'end_of_log': True}], metadata)
+ assert ['MSG1\nMSG2'] == logs
+ assert [{'end_of_log': True}] == metadata
@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client')
@@ -191,14 +191,14 @@ def test_should_read_logs_with_pagination(self, mock_client, mock_get_creds_and_
mock_get_creds_and_project_id.return_value = ('creds', 'project_id')
logs, metadata1 = self.stackdriver_task_handler.read(self.ti, 3)
mock_client.return_value.list_entries.assert_called_once_with(filter_=mock.ANY, page_token=None)
- self.assertEqual(['MSG1\nMSG2'], logs)
- self.assertEqual([{'end_of_log': False, 'next_page_token': 'TOKEN1'}], metadata1)
+ assert ['MSG1\nMSG2'] == logs
+ assert [{'end_of_log': False, 'next_page_token': 'TOKEN1'}] == metadata1
mock_client.return_value.list_entries.return_value.next_page_token = None
logs, metadata2 = self.stackdriver_task_handler.read(self.ti, 3, metadata1[0])
mock_client.return_value.list_entries.assert_called_with(filter_=mock.ANY, page_token="TOKEN1")
- self.assertEqual(['MSG3\nMSG4'], logs)
- self.assertEqual([{'end_of_log': True}], metadata2)
+ assert ['MSG3\nMSG4'] == logs
+ assert [{'end_of_log': True}] == metadata2
@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client')
@@ -211,8 +211,8 @@ def test_should_read_logs_with_download(self, mock_client, mock_get_creds_and_pr
logs, metadata1 = self.stackdriver_task_handler.read(self.ti, 3, {'download_logs': True})
- self.assertEqual(['MSG1\nMSG2\nMSG3\nMSG4'], logs)
- self.assertEqual([{'end_of_log': True}], metadata1)
+ assert ['MSG1\nMSG2\nMSG3\nMSG4'] == logs
+ assert [{'end_of_log': True}] == metadata1
@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
@mock.patch(
@@ -250,8 +250,8 @@ def test_should_read_logs_with_custom_resources(self, mock_client, mock_get_cred
'labels.execution_date="2016-01-01T00:00:00+00:00"',
page_token=None,
)
- self.assertEqual(['TEXT\nTEXT'], logs)
- self.assertEqual([{'end_of_log': True}], metadata)
+ assert ['TEXT\nTEXT'] == logs
+ assert [{'end_of_log': True}] == metadata
@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client')
@@ -275,7 +275,7 @@ def test_should_use_credentials(self, mock_client, mock_get_creds_and_project_id
),
)
mock_client.assert_called_once_with(credentials='creds', client_info=mock.ANY, project="project_id")
- self.assertEqual(mock_client.return_value, client)
+ assert mock_client.return_value == client
@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client')
@@ -291,11 +291,11 @@ def test_should_return_valid_external_url(self, mock_client, mock_get_creds_and_
parsed_url = urlparse(url)
parsed_qs = parse_qs(parsed_url.query)
- self.assertEqual('https', parsed_url.scheme)
- self.assertEqual('console.cloud.google.com', parsed_url.netloc)
- self.assertEqual('/logs/viewer', parsed_url.path)
- self.assertCountEqual(['project', 'interval', 'resource', 'advancedFilter'], parsed_qs.keys())
- self.assertIn('global', parsed_qs['resource'])
+ assert 'https' == parsed_url.scheme
+ assert 'console.cloud.google.com' == parsed_url.netloc
+ assert '/logs/viewer' == parsed_url.path
+ assert {'project', 'interval', 'resource', 'advancedFilter'} == set(parsed_qs.keys())
+ assert 'global' in parsed_qs['resource']
filter_params = parsed_qs['advancedFilter'][0].split('\n')
expected_filter = [
@@ -306,4 +306,4 @@ def test_should_return_valid_external_url(self, mock_client, mock_get_creds_and_
f'labels.execution_date="{self.ti.execution_date.isoformat()}"',
f'labels.try_number="{self.ti.try_number}"',
]
- self.assertCountEqual(expected_filter, filter_params)
+ assert set(expected_filter) == set(filter_params)
diff --git a/tests/providers/google/cloud/log/test_stackdriver_task_handler_system.py b/tests/providers/google/cloud/log/test_stackdriver_task_handler_system.py
index 6ed607ba5e2fe..46964bb0bfae0 100644
--- a/tests/providers/google/cloud/log/test_stackdriver_task_handler_system.py
+++ b/tests/providers/google/cloud/log/test_stackdriver_task_handler_system.py
@@ -58,8 +58,8 @@ def test_should_support_key_auth(self, session):
AIRFLOW__CORE__LOAD_EXAMPLES="false",
AIRFLOW__CORE__DAGS_FOLDER=example_complex.__file__,
):
- self.assertEqual(0, subprocess.Popen(["airflow", "dags", "trigger", "example_complex"]).wait())
- self.assertEqual(0, subprocess.Popen(["airflow", "scheduler", "--num-runs", "1"]).wait())
+ assert 0 == subprocess.Popen(["airflow", "dags", "trigger", "example_complex"]).wait()
+ assert 0 == subprocess.Popen(["airflow", "scheduler", "--num-runs", "1"]).wait()
ti = session.query(TaskInstance).filter(TaskInstance.task_id == "create_entry_group").first()
self.assert_remote_logs("INFO - Task exited with return code 0", ti)
@@ -74,8 +74,8 @@ def test_should_support_adc(self, session):
AIRFLOW__CORE__DAGS_FOLDER=example_complex.__file__,
GOOGLE_APPLICATION_CREDENTIALS=resolve_full_gcp_key_path(GCP_STACKDRIVER),
):
- self.assertEqual(0, subprocess.Popen(["airflow", "dags", "trigger", "example_complex"]).wait())
- self.assertEqual(0, subprocess.Popen(["airflow", "scheduler", "--num-runs", "1"]).wait())
+ assert 0 == subprocess.Popen(["airflow", "dags", "trigger", "example_complex"]).wait()
+ assert 0 == subprocess.Popen(["airflow", "scheduler", "--num-runs", "1"]).wait()
ti = session.query(TaskInstance).filter(TaskInstance.task_id == "create_entry_group").first()
self.assert_remote_logs("INFO - Task exited with return code 0", ti)
@@ -94,4 +94,4 @@ def assert_remote_logs(self, expected_message, ti):
task_log_reader = TaskLogReader()
logs = "\n".join(task_log_reader.read_log_stream(ti, try_number=None, metadata={}))
- self.assertIn(expected_message, logs)
+ assert expected_message in logs
diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py
index 3fa01f88273e1..b41d31a97731a 100644
--- a/tests/providers/google/cloud/operators/test_bigquery.py
+++ b/tests/providers/google/cloud/operators/test_bigquery.py
@@ -427,7 +427,7 @@ def test_execute_bad_type(self, mock_hook):
cluster_fields=None,
)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
operator.execute(MagicMock())
@mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook')
@@ -460,10 +460,10 @@ def test_bigquery_operator_defaults(self, mock_hook):
cluster_fields=None,
encryption_configuration=None,
)
- self.assertTrue(isinstance(operator.sql, str))
+ assert isinstance(operator.sql, str)
ti = TaskInstance(task=operator, execution_date=DEFAULT_DATE)
ti.render_templates()
- self.assertTrue(isinstance(ti.task.sql, str))
+ assert isinstance(ti.task.sql, str)
def test_bigquery_operator_extra_serialized_field_when_single_query(self):
with self.dag:
@@ -472,35 +472,34 @@ def test_bigquery_operator_extra_serialized_field_when_single_query(self):
sql='SELECT * FROM test_table',
)
serialized_dag = SerializedDAG.to_dict(self.dag)
- self.assertIn("sql", serialized_dag["dag"]["tasks"][0])
+ assert "sql" in serialized_dag["dag"]["tasks"][0]
dag = SerializedDAG.from_dict(serialized_dag)
simple_task = dag.task_dict[TASK_ID]
- self.assertEqual(getattr(simple_task, "sql"), 'SELECT * FROM test_table')
+ assert getattr(simple_task, "sql") == 'SELECT * FROM test_table'
#########################################################
# Verify Operator Links work with Serialized Operator
#########################################################
# Check Serialized version of operator link
- self.assertEqual(
- serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
- [{'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink': {}}],
- )
+ assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
+ {'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink': {}}
+ ]
# Check DeSerialized version of operator link
- self.assertIsInstance(list(simple_task.operator_extra_links)[0], BigQueryConsoleLink)
+ assert isinstance(list(simple_task.operator_extra_links)[0], BigQueryConsoleLink)
ti = TaskInstance(task=simple_task, execution_date=DEFAULT_DATE)
ti.xcom_push('job_id', 12345)
# check for positive case
url = simple_task.get_extra_links(DEFAULT_DATE, BigQueryConsoleLink.name)
- self.assertEqual(url, 'https://console.cloud.google.com/bigquery?j=12345')
+ assert url == 'https://console.cloud.google.com/bigquery?j=12345'
# check for negative case
url2 = simple_task.get_extra_links(datetime(2017, 1, 2), BigQueryConsoleLink.name)
- self.assertEqual(url2, '')
+ assert url2 == ''
def test_bigquery_operator_extra_serialized_field_when_multiple_queries(self):
with self.dag:
@@ -509,54 +508,37 @@ def test_bigquery_operator_extra_serialized_field_when_multiple_queries(self):
sql=['SELECT * FROM test_table', 'SELECT * FROM test_table2'],
)
serialized_dag = SerializedDAG.to_dict(self.dag)
- self.assertIn("sql", serialized_dag["dag"]["tasks"][0])
+ assert "sql" in serialized_dag["dag"]["tasks"][0]
dag = SerializedDAG.from_dict(serialized_dag)
simple_task = dag.task_dict[TASK_ID]
- self.assertEqual(
- getattr(simple_task, "sql"), ['SELECT * FROM test_table', 'SELECT * FROM test_table2']
- )
+ assert getattr(simple_task, "sql") == ['SELECT * FROM test_table', 'SELECT * FROM test_table2']
#########################################################
# Verify Operator Links work with Serialized Operator
#########################################################
# Check Serialized version of operator link
- self.assertEqual(
- serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
- [
- {
- 'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink': {
- 'index': 0
- }
- },
- {
- 'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink': {
- 'index': 1
- }
- },
- ],
- )
+ assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
+ {'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink': {'index': 0}},
+ {'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink': {'index': 1}},
+ ]
# Check DeSerialized version of operator link
- self.assertIsInstance(list(simple_task.operator_extra_links)[0], BigQueryConsoleIndexableLink)
+ assert isinstance(list(simple_task.operator_extra_links)[0], BigQueryConsoleIndexableLink)
ti = TaskInstance(task=simple_task, execution_date=DEFAULT_DATE)
job_id = ['123', '45']
ti.xcom_push(key='job_id', value=job_id)
- self.assertEqual(
- {'BigQuery Console #1', 'BigQuery Console #2'}, simple_task.operator_extra_link_dict.keys()
- )
+ assert {'BigQuery Console #1', 'BigQuery Console #2'} == simple_task.operator_extra_link_dict.keys()
- self.assertEqual(
- 'https://console.cloud.google.com/bigquery?j=123',
- simple_task.get_extra_links(DEFAULT_DATE, 'BigQuery Console #1'),
+ assert 'https://console.cloud.google.com/bigquery?j=123' == simple_task.get_extra_links(
+ DEFAULT_DATE, 'BigQuery Console #1'
)
- self.assertEqual(
- 'https://console.cloud.google.com/bigquery?j=45',
- simple_task.get_extra_links(DEFAULT_DATE, 'BigQuery Console #2'),
+ assert 'https://console.cloud.google.com/bigquery?j=45' == simple_task.get_extra_links(
+ DEFAULT_DATE, 'BigQuery Console #2'
)
@provide_session
@@ -570,10 +552,7 @@ def test_bigquery_operator_extra_link_when_missing_job_id(self, mock_hook, sessi
self.dag.clear()
session.query(XCom).delete()
- self.assertEqual(
- '',
- bigquery_task.get_extra_links(DEFAULT_DATE, BigQueryConsoleLink.name),
- )
+ assert '' == bigquery_task.get_extra_links(DEFAULT_DATE, BigQueryConsoleLink.name)
@provide_session
@mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook')
@@ -594,15 +573,11 @@ def test_bigquery_operator_extra_link_when_single_query(self, mock_hook, session
job_id = '12345'
ti.xcom_push(key='job_id', value=job_id)
- self.assertEqual(
- f'https://console.cloud.google.com/bigquery?j={job_id}',
- bigquery_task.get_extra_links(DEFAULT_DATE, BigQueryConsoleLink.name),
+ assert f'https://console.cloud.google.com/bigquery?j={job_id}' == bigquery_task.get_extra_links(
+ DEFAULT_DATE, BigQueryConsoleLink.name
)
- self.assertEqual(
- '',
- bigquery_task.get_extra_links(datetime(2019, 1, 1), BigQueryConsoleLink.name),
- )
+ assert '' == bigquery_task.get_extra_links(datetime(2019, 1, 1), BigQueryConsoleLink.name)
@provide_session
@mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook')
@@ -623,18 +598,14 @@ def test_bigquery_operator_extra_link_when_multiple_query(self, mock_hook, sessi
job_id = ['123', '45']
ti.xcom_push(key='job_id', value=job_id)
- self.assertEqual(
- {'BigQuery Console #1', 'BigQuery Console #2'}, bigquery_task.operator_extra_link_dict.keys()
- )
+ assert {'BigQuery Console #1', 'BigQuery Console #2'} == bigquery_task.operator_extra_link_dict.keys()
- self.assertEqual(
- 'https://console.cloud.google.com/bigquery?j=123',
- bigquery_task.get_extra_links(DEFAULT_DATE, 'BigQuery Console #1'),
+ assert 'https://console.cloud.google.com/bigquery?j=123' == bigquery_task.get_extra_links(
+ DEFAULT_DATE, 'BigQuery Console #1'
)
- self.assertEqual(
- 'https://console.cloud.google.com/bigquery?j=45',
- bigquery_task.get_extra_links(DEFAULT_DATE, 'BigQuery Console #2'),
+ assert 'https://console.cloud.google.com/bigquery?j=45' == bigquery_task.get_extra_links(
+ DEFAULT_DATE, 'BigQuery Console #2'
)
@@ -716,12 +687,15 @@ class TestBigQueryConnIdDeprecationWarning(unittest.TestCase):
)
def test_bigquery_conn_id_deprecation_warning(self, operator_class, kwargs):
bigquery_conn_id = 'google_cloud_default'
- with self.assertWarnsRegex(
+ with pytest.warns(
DeprecationWarning,
- "The bigquery_conn_id parameter has been deprecated. You should pass the gcp_conn_id parameter.",
+ match=(
+ "The bigquery_conn_id parameter has been deprecated. "
+ "You should pass the gcp_conn_id parameter."
+ ),
):
operator = operator_class(bigquery_conn_id=bigquery_conn_id, **kwargs)
- self.assertEqual(bigquery_conn_id, operator.gcp_conn_id)
+ assert bigquery_conn_id == operator.gcp_conn_id
class TestBigQueryUpsertTableOperator(unittest.TestCase):
diff --git a/tests/providers/google/cloud/operators/test_bigtable.py b/tests/providers/google/cloud/operators/test_bigtable.py
index 5e95ec956d1aa..30016e5f43961 100644
--- a/tests/providers/google/cloud/operators/test_bigtable.py
+++ b/tests/providers/google/cloud/operators/test_bigtable.py
@@ -21,6 +21,7 @@
from unittest import mock
import google.api_core.exceptions
+import pytest
from google.cloud.bigtable.column_family import MaxVersionsGCRule
from google.cloud.bigtable.instance import Instance
from google.cloud.bigtable_admin_v2 import enums
@@ -69,7 +70,7 @@ class TestBigtableInstanceCreate(unittest.TestCase):
def test_empty_attribute(
self, missing_attribute, project_id, instance_id, main_cluster_id, main_cluster_zone, mock_hook
):
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
BigtableCreateInstanceOperator(
project_id=project_id,
instance_id=instance_id,
@@ -78,8 +79,8 @@ def test_empty_attribute(
task_id="id",
gcp_conn_id=GCP_CONN_ID,
)
- err = e.exception
- self.assertEqual(str(err), f'Empty parameter: {missing_attribute}')
+ err = ctx.value
+ assert str(err) == f'Empty parameter: {missing_attribute}'
mock_hook.assert_not_called()
@mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook')
@@ -140,7 +141,7 @@ def test_different_error_reraised(self, mock_hook):
side_effect=google.api_core.exceptions.GoogleAPICallError('error')
)
- with self.assertRaises(google.api_core.exceptions.GoogleAPICallError):
+ with pytest.raises(google.api_core.exceptions.GoogleAPICallError):
op.execute(None)
mock_hook.assert_called_once_with(
@@ -291,7 +292,7 @@ def test_update_execute_empty_project_id(self, mock_hook):
)
@mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook')
def test_empty_attribute(self, missing_attribute, project_id, instance_id, mock_hook):
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
BigtableUpdateInstanceOperator(
project_id=project_id,
instance_id=instance_id,
@@ -300,15 +301,15 @@ def test_empty_attribute(self, missing_attribute, project_id, instance_id, mock_
instance_labels=INSTANCE_LABELS,
task_id="id",
)
- err = e.exception
- self.assertEqual(str(err), f'Empty parameter: {missing_attribute}')
+ err = ctx.value
+ assert str(err) == f'Empty parameter: {missing_attribute}'
mock_hook.assert_not_called()
@mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook')
def test_update_instance_that_doesnt_exists(self, mock_hook):
mock_hook.return_value.get_instance.return_value = None
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
op = BigtableUpdateInstanceOperator(
project_id=PROJECT_ID,
instance_id=INSTANCE_ID,
@@ -321,8 +322,8 @@ def test_update_instance_that_doesnt_exists(self, mock_hook):
)
op.execute(None)
- err = e.exception
- self.assertEqual(str(err), f"Dependency: instance '{INSTANCE_ID}' does not exist.")
+ err = ctx.value
+ assert str(err) == f"Dependency: instance '{INSTANCE_ID}' does not exist."
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
@@ -334,7 +335,7 @@ def test_update_instance_that_doesnt_exists(self, mock_hook):
def test_update_instance_that_doesnt_exists_empty_project_id(self, mock_hook):
mock_hook.return_value.get_instance.return_value = None
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
op = BigtableUpdateInstanceOperator(
instance_id=INSTANCE_ID,
instance_display_name=INSTANCE_DISPLAY_NAME,
@@ -346,8 +347,8 @@ def test_update_instance_that_doesnt_exists_empty_project_id(self, mock_hook):
)
op.execute(None)
- err = e.exception
- self.assertEqual(str(err), f"Dependency: instance '{INSTANCE_ID}' does not exist.")
+ err = ctx.value
+ assert str(err) == f"Dependency: instance '{INSTANCE_ID}' does not exist."
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
@@ -371,7 +372,7 @@ def test_different_error_reraised(self, mock_hook):
side_effect=google.api_core.exceptions.GoogleAPICallError('error')
)
- with self.assertRaises(google.api_core.exceptions.GoogleAPICallError):
+ with pytest.raises(google.api_core.exceptions.GoogleAPICallError):
op.execute(None)
mock_hook.assert_called_once_with(
@@ -399,7 +400,7 @@ class TestBigtableClusterUpdate(unittest.TestCase):
)
@mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook')
def test_empty_attribute(self, missing_attribute, project_id, instance_id, cluster_id, nodes, mock_hook):
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
BigtableUpdateClusterOperator(
project_id=project_id,
instance_id=instance_id,
@@ -408,15 +409,15 @@ def test_empty_attribute(self, missing_attribute, project_id, instance_id, clust
task_id="id",
gcp_conn_id=GCP_CONN_ID,
)
- err = e.exception
- self.assertEqual(str(err), f'Empty parameter: {missing_attribute}')
+ err = ctx.value
+ assert str(err) == f'Empty parameter: {missing_attribute}'
mock_hook.assert_not_called()
@mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook')
def test_updating_cluster_but_instance_does_not_exists(self, mock_hook):
mock_hook.return_value.get_instance.return_value = None
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
op = BigtableUpdateClusterOperator(
project_id=PROJECT_ID,
instance_id=INSTANCE_ID,
@@ -428,8 +429,8 @@ def test_updating_cluster_but_instance_does_not_exists(self, mock_hook):
)
op.execute(None)
- err = e.exception
- self.assertEqual(str(err), f"Dependency: instance '{INSTANCE_ID}' does not exist.")
+ err = ctx.value
+ assert str(err) == f"Dependency: instance '{INSTANCE_ID}' does not exist."
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
@@ -440,7 +441,7 @@ def test_updating_cluster_but_instance_does_not_exists(self, mock_hook):
def test_updating_cluster_but_instance_does_not_exists_empty_project_id(self, mock_hook):
mock_hook.return_value.get_instance.return_value = None
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
op = BigtableUpdateClusterOperator(
instance_id=INSTANCE_ID,
cluster_id=CLUSTER_ID,
@@ -451,8 +452,8 @@ def test_updating_cluster_but_instance_does_not_exists_empty_project_id(self, mo
)
op.execute(None)
- err = e.exception
- self.assertEqual(str(err), f"Dependency: instance '{INSTANCE_ID}' does not exist.")
+ err = ctx.value
+ assert str(err) == f"Dependency: instance '{INSTANCE_ID}' does not exist."
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
@@ -466,7 +467,7 @@ def test_updating_cluster_that_does_not_exists(self, mock_hook):
side_effect=google.api_core.exceptions.NotFound("Cluster not found.")
)
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
op = BigtableUpdateClusterOperator(
project_id=PROJECT_ID,
instance_id=INSTANCE_ID,
@@ -478,11 +479,8 @@ def test_updating_cluster_that_does_not_exists(self, mock_hook):
)
op.execute(None)
- err = e.exception
- self.assertEqual(
- str(err),
- f"Dependency: cluster '{CLUSTER_ID}' does not exist for instance '{INSTANCE_ID}'.",
- )
+ err = ctx.value
+ assert str(err) == f"Dependency: cluster '{CLUSTER_ID}' does not exist for instance '{INSTANCE_ID}'."
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
@@ -498,7 +496,7 @@ def test_updating_cluster_that_does_not_exists_empty_project_id(self, mock_hook)
side_effect=google.api_core.exceptions.NotFound("Cluster not found.")
)
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
op = BigtableUpdateClusterOperator(
instance_id=INSTANCE_ID,
cluster_id=CLUSTER_ID,
@@ -509,11 +507,8 @@ def test_updating_cluster_that_does_not_exists_empty_project_id(self, mock_hook)
)
op.execute(None)
- err = e.exception
- self.assertEqual(
- str(err),
- f"Dependency: cluster '{CLUSTER_ID}' does not exist for instance '{INSTANCE_ID}'.",
- )
+ err = ctx.value
+ assert str(err) == f"Dependency: cluster '{CLUSTER_ID}' does not exist for instance '{INSTANCE_ID}'."
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
@@ -538,7 +533,7 @@ def test_different_error_reraised(self, mock_hook):
side_effect=google.api_core.exceptions.GoogleAPICallError('error')
)
- with self.assertRaises(google.api_core.exceptions.GoogleAPICallError):
+ with pytest.raises(google.api_core.exceptions.GoogleAPICallError):
op.execute(None)
mock_hook.assert_called_once_with(
@@ -594,10 +589,10 @@ def test_delete_execute_empty_project_id(self, mock_hook):
)
@mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook')
def test_empty_attribute(self, missing_attribute, project_id, instance_id, mock_hook):
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
BigtableDeleteInstanceOperator(project_id=project_id, instance_id=instance_id, task_id="id")
- err = e.exception
- self.assertEqual(str(err), f'Empty parameter: {missing_attribute}')
+ err = ctx.value
+ assert str(err) == f'Empty parameter: {missing_attribute}'
mock_hook.assert_not_called()
@mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook')
@@ -654,7 +649,7 @@ def test_different_error_reraised(self, mock_hook):
side_effect=google.api_core.exceptions.GoogleAPICallError('error')
)
- with self.assertRaises(google.api_core.exceptions.GoogleAPICallError):
+ with pytest.raises(google.api_core.exceptions.GoogleAPICallError):
op.execute(None)
mock_hook.assert_called_once_with(
@@ -695,7 +690,7 @@ def test_delete_execute(self, mock_hook):
)
@mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook')
def test_empty_attribute(self, missing_attribute, project_id, instance_id, table_id, mock_hook):
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
BigtableDeleteTableOperator(
project_id=project_id,
instance_id=instance_id,
@@ -703,8 +698,8 @@ def test_empty_attribute(self, missing_attribute, project_id, instance_id, table
task_id="id",
gcp_conn_id=GCP_CONN_ID,
)
- err = e.exception
- self.assertEqual(str(err), f'Empty parameter: {missing_attribute}')
+ err = ctx.value
+ assert str(err) == f'Empty parameter: {missing_attribute}'
mock_hook.assert_not_called()
@mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook')
@@ -764,10 +759,10 @@ def test_deleting_table_when_instance_doesnt_exists(self, mock_hook):
)
mock_hook.return_value.get_instance.return_value = None
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
op.execute(None)
- err = e.exception
- self.assertEqual(str(err), f"Dependency: instance '{INSTANCE_ID}' does not exist.")
+ err = ctx.value
+ assert str(err) == f"Dependency: instance '{INSTANCE_ID}' does not exist."
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
@@ -788,7 +783,7 @@ def test_different_error_reraised(self, mock_hook):
side_effect=google.api_core.exceptions.GoogleAPICallError('error')
)
- with self.assertRaises(google.api_core.exceptions.GoogleAPICallError):
+ with pytest.raises(google.api_core.exceptions.GoogleAPICallError):
op.execute(None)
mock_hook.assert_called_once_with(
@@ -835,7 +830,7 @@ def test_create_execute(self, mock_hook):
)
@mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook')
def test_empty_attribute(self, missing_attribute, project_id, instance_id, table_id, mock_hook):
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
BigtableCreateTableOperator(
project_id=project_id,
instance_id=instance_id,
@@ -843,8 +838,8 @@ def test_empty_attribute(self, missing_attribute, project_id, instance_id, table
task_id="id",
gcp_conn_id=GCP_CONN_ID,
)
- err = e.exception
- self.assertEqual(str(err), f'Empty parameter: {missing_attribute}')
+ err = ctx.value
+ assert str(err) == f'Empty parameter: {missing_attribute}'
mock_hook.assert_not_called()
@mock.patch('airflow.providers.google.cloud.operators.bigtable.BigtableHook')
@@ -860,13 +855,10 @@ def test_instance_not_exists(self, mock_hook):
impersonation_chain=IMPERSONATION_CHAIN,
)
mock_hook.return_value.get_instance.return_value = None
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
op.execute(None)
- err = e.exception
- self.assertEqual(
- str(err),
- f"Dependency: instance '{INSTANCE_ID}' does not exist in project '{PROJECT_ID}'.",
- )
+ err = ctx.value
+ assert str(err) == f"Dependency: instance '{INSTANCE_ID}' does not exist in project '{PROJECT_ID}'."
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
@@ -951,10 +943,10 @@ def test_creating_table_that_exists_with_different_column_families_ids_in_the_ta
side_effect=google.api_core.exceptions.AlreadyExists("Table already exists.")
)
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
op.execute(None)
- err = e.exception
- self.assertEqual(str(err), f"Table '{TABLE_ID}' already exists with different Column Families.")
+ err = ctx.value
+ assert str(err) == f"Table '{TABLE_ID}' already exists with different Column Families."
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
@@ -981,10 +973,10 @@ def test_creating_table_that_exists_with_different_column_families_gc_rule_in__t
side_effect=google.api_core.exceptions.AlreadyExists("Table already exists.")
)
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
op.execute(None)
- err = e.exception
- self.assertEqual(str(err), f"Table '{TABLE_ID}' already exists with different Column Families.")
+ err = ctx.value
+ assert str(err) == f"Table '{TABLE_ID}' already exists with different Column Families."
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
diff --git a/tests/providers/google/cloud/operators/test_cloud_build.py b/tests/providers/google/cloud/operators/test_cloud_build.py
index 970ee0606ef0f..3d0e0f009bd9c 100644
--- a/tests/providers/google/cloud/operators/test_cloud_build.py
+++ b/tests/providers/google/cloud/operators/test_cloud_build.py
@@ -22,6 +22,7 @@
from datetime import datetime
from unittest import TestCase, mock
+import pytest
from parameterized import parameterized
from airflow.exceptions import AirflowException
@@ -41,7 +42,7 @@
class TestBuildProcessor(TestCase):
def test_verify_source(self):
- with self.assertRaisesRegex(AirflowException, "The source could not be determined."):
+ with pytest.raises(AirflowException, match="The source could not be determined."):
BuildProcessor(body={"source": {"storageSource": {}, "repoSource": {}}}).process_body()
@parameterized.expand(
@@ -63,7 +64,7 @@ def test_verify_source(self):
def test_convert_repo_url_to_dict_valid(self, url, expected_dict):
body = {"source": {"repoSource": url}}
body = BuildProcessor(body=body).process_body()
- self.assertEqual(body["source"]["repoSource"], expected_dict)
+ assert body["source"]["repoSource"] == expected_dict
@parameterized.expand(
[
@@ -76,7 +77,7 @@ def test_convert_repo_url_to_dict_valid(self, url, expected_dict):
)
def test_convert_repo_url_to_storage_dict_invalid(self, url):
body = {"source": {"repoSource": url}}
- with self.assertRaisesRegex(AirflowException, "Invalid URL."):
+ with pytest.raises(AirflowException, match="Invalid URL."):
BuildProcessor(body=body).process_body()
@parameterized.expand(
@@ -94,14 +95,14 @@ def test_convert_repo_url_to_storage_dict_invalid(self, url):
def test_convert_storage_url_to_dict_valid(self, url, expected_dict):
body = {"source": {"storageSource": url}}
body = BuildProcessor(body=body).process_body()
- self.assertEqual(body["source"]["storageSource"], expected_dict)
+ assert body["source"]["storageSource"] == expected_dict
@parameterized.expand(
[("///object",), ("gsXXa:///object",), ("gs://bucket-name/",), ("gs://bucket-name",)]
)
def test_convert_storage_url_to_dict_invalid(self, url):
body = {"source": {"storageSource": url}}
- with self.assertRaisesRegex(AirflowException, "Invalid URL."):
+ with pytest.raises(AirflowException, match="Invalid URL."):
BuildProcessor(body=body).process_body()
@parameterized.expand([("storageSource",), ("repoSource",)])
@@ -110,7 +111,7 @@ def test_do_nothing(self, source_key):
expected_body = deepcopy(body)
BuildProcessor(body=body).process_body()
- self.assertEqual(body, expected_body)
+ assert body == expected_body
class TestGcpCloudBuildCreateBuildOperator(TestCase):
@@ -121,11 +122,11 @@ def test_minimal_green_path(self, mock_hook):
body=TEST_CREATE_BODY, project_id=TEST_PROJECT_ID, task_id="task-id"
)
result = operator.execute({})
- self.assertIs(result, TEST_CREATE_BODY)
+ assert result is TEST_CREATE_BODY
@parameterized.expand([({},), (None,)])
def test_missing_input(self, body):
- with self.assertRaisesRegex(AirflowException, "The required parameter 'body' is missing"):
+ with pytest.raises(AirflowException, match="The required parameter 'body' is missing"):
CloudBuildCreateBuildOperator(body=body, project_id=TEST_PROJECT_ID, task_id="task-id")
@mock.patch("airflow.providers.google.cloud.operators.cloud_build.CloudBuildHook")
@@ -206,7 +207,7 @@ def test_repo_source_replace(self, hook_mock):
hook_mock.return_value.create_build.assert_called_once_with(
body=expected_body, project_id=TEST_PROJECT_ID
)
- self.assertEqual(return_value, TEST_CREATE_BODY)
+ assert return_value == TEST_CREATE_BODY
def test_load_templated_yaml(self):
dag = DAG(dag_id='example_cloudbuild_operator', start_date=TEST_DEFAULT_DATE)
@@ -227,4 +228,4 @@ def test_load_templated_yaml(self):
ti = TaskInstance(operator, TEST_DEFAULT_DATE)
ti.render_templates()
expected_body = {'steps': [{'name': 'ubuntu', 'args': ['echo', 'Hello airflow!']}]}
- self.assertEqual(expected_body, operator.body)
+ assert expected_body == operator.body
diff --git a/tests/providers/google/cloud/operators/test_cloud_sql.py b/tests/providers/google/cloud/operators/test_cloud_sql.py
index a159c7497197d..3a224788f288f 100644
--- a/tests/providers/google/cloud/operators/test_cloud_sql.py
+++ b/tests/providers/google/cloud/operators/test_cloud_sql.py
@@ -22,6 +22,7 @@
import unittest
from unittest import mock
+import pytest
from parameterized import parameterized
from airflow.exceptions import AirflowException
@@ -159,7 +160,7 @@ def test_instance_create(self, mock_hook, _check_if_instance_exists):
mock_hook.return_value.create_instance.assert_called_once_with(
project_id=PROJECT_ID, body=CREATE_BODY
)
- self.assertIsNone(result)
+ assert result is None
@mock.patch(
"airflow.providers.google.cloud.operators.cloud_sql"
@@ -179,7 +180,7 @@ def test_instance_create_missing_project_id(self, mock_hook, _check_if_instance_
impersonation_chain=None,
)
mock_hook.return_value.create_instance.assert_called_once_with(project_id=None, body=CREATE_BODY)
- self.assertIsNone(result)
+ assert result is None
@mock.patch(
"airflow.providers.google.cloud.operators.cloud_sql"
@@ -201,39 +202,39 @@ def test_instance_create_idempotent(self, mock_hook, _check_if_instance_exists):
impersonation_chain=None,
)
mock_hook.return_value.create_instance.assert_not_called()
- self.assertIsNone(result)
+ assert result is None
@mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook")
def test_create_should_throw_ex_when_empty_project_id(self, mock_hook):
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = CloudSQLCreateInstanceOperator(
project_id="", body=CREATE_BODY, instance=INSTANCE_NAME, task_id="id"
)
op.execute(None)
- err = cm.exception
- self.assertIn("The required parameter 'project_id' is empty", str(err))
+ err = ctx.value
+ assert "The required parameter 'project_id' is empty" in str(err)
mock_hook.assert_not_called()
@mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook")
def test_create_should_throw_ex_when_empty_body(self, mock_hook):
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = CloudSQLCreateInstanceOperator(
project_id=PROJECT_ID, body={}, instance=INSTANCE_NAME, task_id="id"
)
op.execute(None)
- err = cm.exception
- self.assertIn("The required parameter 'body' is empty", str(err))
+ err = ctx.value
+ assert "The required parameter 'body' is empty" in str(err)
mock_hook.assert_not_called()
@mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook")
def test_create_should_throw_ex_when_empty_instance(self, mock_hook):
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = CloudSQLCreateInstanceOperator(
project_id=PROJECT_ID, body=CREATE_BODY, instance="", task_id="id"
)
op.execute(None)
- err = cm.exception
- self.assertIn("The required parameter 'instance' is empty", str(err))
+ err = ctx.value
+ assert "The required parameter 'instance' is empty" in str(err)
mock_hook.assert_not_called()
@mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook")
@@ -248,16 +249,15 @@ def test_create_should_validate_list_type(self, mock_hook):
},
},
}
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = CloudSQLCreateInstanceOperator(
project_id=PROJECT_ID, body=wrong_list_type_body, instance=INSTANCE_NAME, task_id="id"
)
op.execute(None)
- err = cm.exception
- self.assertIn(
+ err = ctx.value
+ assert (
"The field 'settings.ipConfiguration.authorizedNetworks' "
- "should be of list type according to the specification",
- str(err),
+ "should be of list type according to the specification" in str(err)
)
mock_hook.assert_called_once_with(
api_version="v1beta4",
@@ -274,13 +274,13 @@ def test_create_should_validate_non_empty_fields(self, mock_hook):
# Testing if the validation catches this.
},
}
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = CloudSQLCreateInstanceOperator(
project_id=PROJECT_ID, body=empty_tier_body, instance=INSTANCE_NAME, task_id="id"
)
op.execute(None)
- err = cm.exception
- self.assertIn("The body field 'settings.tier' can't be empty. Please provide a value.", str(err))
+ err = ctx.value
+ assert "The body field 'settings.tier' can't be empty. Please provide a value." in str(err)
mock_hook.assert_called_once_with(
api_version="v1beta4",
gcp_conn_id="google_cloud_default",
@@ -302,7 +302,7 @@ def test_instance_patch(self, mock_hook):
mock_hook.return_value.patch_instance.assert_called_once_with(
project_id=PROJECT_ID, body=PATCH_BODY, instance=INSTANCE_NAME
)
- self.assertTrue(result)
+ assert result
@mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook")
def test_instance_patch_missing_project_id(self, mock_hook):
@@ -317,7 +317,7 @@ def test_instance_patch_missing_project_id(self, mock_hook):
mock_hook.return_value.patch_instance.assert_called_once_with(
project_id=None, body=PATCH_BODY, instance=INSTANCE_NAME
)
- self.assertTrue(result)
+ assert result
@mock.patch(
"airflow.providers.google.cloud.operators.cloud_sql"
@@ -326,13 +326,13 @@ def test_instance_patch_missing_project_id(self, mock_hook):
@mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook")
def test_instance_patch_should_bubble_up_ex_if_not_exists(self, mock_hook, _check_if_instance_exists):
_check_if_instance_exists.return_value = False
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = CloudSQLInstancePatchOperator(
project_id=PROJECT_ID, body=PATCH_BODY, instance=INSTANCE_NAME, task_id="id"
)
op.execute(None)
- err = cm.exception
- self.assertIn('specify another instance to patch', str(err))
+ err = ctx.value
+ assert 'specify another instance to patch' in str(err)
mock_hook.assert_called_once_with(
api_version="v1beta4",
gcp_conn_id="google_cloud_default",
@@ -349,7 +349,7 @@ def test_instance_delete(self, mock_hook, _check_if_instance_exists):
_check_if_instance_exists.return_value = True
op = CloudSQLDeleteInstanceOperator(project_id=PROJECT_ID, instance=INSTANCE_NAME, task_id="id")
result = op.execute(None)
- self.assertTrue(result)
+ assert result
mock_hook.assert_called_once_with(
api_version="v1beta4",
gcp_conn_id="google_cloud_default",
@@ -368,7 +368,7 @@ def test_instance_delete_missing_project_id(self, mock_hook, _check_if_instance_
_check_if_instance_exists.return_value = True
op = CloudSQLDeleteInstanceOperator(instance=INSTANCE_NAME, task_id="id")
result = op.execute(None)
- self.assertTrue(result)
+ assert result
mock_hook.assert_called_once_with(
api_version="v1beta4",
gcp_conn_id="google_cloud_default",
@@ -389,7 +389,7 @@ def test_instance_delete_should_abort_and_succeed_if_not_exists(
_check_if_instance_exists.return_value = False
op = CloudSQLDeleteInstanceOperator(project_id=PROJECT_ID, instance=INSTANCE_NAME, task_id="id")
result = op.execute(None)
- self.assertTrue(result)
+ assert result
mock_hook.assert_called_once_with(
api_version="v1beta4",
gcp_conn_id="google_cloud_default",
@@ -416,7 +416,7 @@ def test_instance_db_create(self, mock_hook, _check_if_db_exists):
mock_hook.return_value.create_database.assert_called_once_with(
project_id=PROJECT_ID, instance=INSTANCE_NAME, body=DATABASE_INSERT_BODY
)
- self.assertTrue(result)
+ assert result
@mock.patch(
"airflow.providers.google.cloud.operators.cloud_sql"
@@ -437,7 +437,7 @@ def test_instance_db_create_missing_project_id(self, mock_hook, _check_if_db_exi
mock_hook.return_value.create_database.assert_called_once_with(
project_id=None, instance=INSTANCE_NAME, body=DATABASE_INSERT_BODY
)
- self.assertTrue(result)
+ assert result
@mock.patch(
"airflow.providers.google.cloud.operators.cloud_sql"
@@ -450,7 +450,7 @@ def test_instance_db_create_should_abort_and_succeed_if_exists(self, mock_hook,
project_id=PROJECT_ID, instance=INSTANCE_NAME, body=DATABASE_INSERT_BODY, task_id="id"
)
result = op.execute(None)
- self.assertTrue(result)
+ assert result
mock_hook.assert_called_once_with(
api_version="v1beta4",
gcp_conn_id="google_cloud_default",
@@ -481,7 +481,7 @@ def test_instance_db_patch(self, mock_hook, _check_if_db_exists):
mock_hook.return_value.patch_database.assert_called_once_with(
project_id=PROJECT_ID, instance=INSTANCE_NAME, database=DB_NAME, body=DATABASE_PATCH_BODY
)
- self.assertTrue(result)
+ assert result
@mock.patch(
"airflow.providers.google.cloud.operators.cloud_sql"
@@ -502,7 +502,7 @@ def test_instance_db_patch_missing_project_id(self, mock_hook, _check_if_db_exis
mock_hook.return_value.patch_database.assert_called_once_with(
project_id=None, instance=INSTANCE_NAME, database=DB_NAME, body=DATABASE_PATCH_BODY
)
- self.assertTrue(result)
+ assert result
@mock.patch(
"airflow.providers.google.cloud.operators.cloud_sql"
@@ -511,7 +511,7 @@ def test_instance_db_patch_missing_project_id(self, mock_hook, _check_if_db_exis
@mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook")
def test_instance_db_patch_should_throw_ex_if_not_exists(self, mock_hook, _check_if_db_exists):
_check_if_db_exists.return_value = False
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = CloudSQLPatchInstanceDatabaseOperator(
project_id=PROJECT_ID,
instance=INSTANCE_NAME,
@@ -520,9 +520,9 @@ def test_instance_db_patch_should_throw_ex_if_not_exists(self, mock_hook, _check
task_id="id",
)
op.execute(None)
- err = cm.exception
- self.assertIn("Cloud SQL instance with ID", str(err))
- self.assertIn("does not contain database", str(err))
+ err = ctx.value
+ assert "Cloud SQL instance with ID" in str(err)
+ assert "does not contain database" in str(err)
mock_hook.assert_called_once_with(
api_version="v1beta4",
gcp_conn_id="google_cloud_default",
@@ -532,7 +532,7 @@ def test_instance_db_patch_should_throw_ex_if_not_exists(self, mock_hook, _check
@mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook")
def test_instance_db_patch_should_throw_ex_when_empty_database(self, mock_hook):
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = CloudSQLPatchInstanceDatabaseOperator(
project_id=PROJECT_ID,
instance=INSTANCE_NAME,
@@ -541,8 +541,8 @@ def test_instance_db_patch_should_throw_ex_when_empty_database(self, mock_hook):
task_id="id",
)
op.execute(None)
- err = cm.exception
- self.assertIn("The required parameter 'database' is empty", str(err))
+ err = ctx.value
+ assert "The required parameter 'database' is empty" in str(err)
mock_hook.assert_not_called()
mock_hook.return_value.patch_database.assert_not_called()
@@ -557,7 +557,7 @@ def test_instance_db_delete(self, mock_hook, _check_if_db_exists):
project_id=PROJECT_ID, instance=INSTANCE_NAME, database=DB_NAME, task_id="id"
)
result = op.execute(None)
- self.assertTrue(result)
+ assert result
mock_hook.assert_called_once_with(
api_version="v1beta4",
gcp_conn_id="google_cloud_default",
@@ -576,7 +576,7 @@ def test_instance_db_delete_missing_project_id(self, mock_hook, _check_if_db_exi
_check_if_db_exists.return_value = True
op = CloudSQLDeleteInstanceDatabaseOperator(instance=INSTANCE_NAME, database=DB_NAME, task_id="id")
result = op.execute(None)
- self.assertTrue(result)
+ assert result
mock_hook.assert_called_once_with(
api_version="v1beta4",
gcp_conn_id="google_cloud_default",
@@ -597,7 +597,7 @@ def test_instance_db_delete_should_abort_and_succeed_if_not_exists(self, mock_ho
project_id=PROJECT_ID, instance=INSTANCE_NAME, database=DB_NAME, task_id="id"
)
result = op.execute(None)
- self.assertTrue(result)
+ assert result
mock_hook.assert_called_once_with(
api_version="v1beta4",
gcp_conn_id="google_cloud_default",
@@ -620,7 +620,7 @@ def test_instance_export(self, mock_hook):
mock_hook.return_value.export_instance.assert_called_once_with(
project_id=PROJECT_ID, instance=INSTANCE_NAME, body=EXPORT_BODY
)
- self.assertTrue(result)
+ assert result
@mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook")
def test_instance_export_missing_project_id(self, mock_hook):
@@ -635,7 +635,7 @@ def test_instance_export_missing_project_id(self, mock_hook):
mock_hook.return_value.export_instance.assert_called_once_with(
project_id=None, instance=INSTANCE_NAME, body=EXPORT_BODY
)
- self.assertTrue(result)
+ assert result
@mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook")
def test_instance_import(self, mock_hook):
@@ -652,7 +652,7 @@ def test_instance_import(self, mock_hook):
mock_hook.return_value.import_instance.assert_called_once_with(
project_id=PROJECT_ID, instance=INSTANCE_NAME, body=IMPORT_BODY
)
- self.assertTrue(result)
+ assert result
@mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook")
def test_instance_import_missing_project_id(self, mock_hook):
@@ -667,7 +667,7 @@ def test_instance_import_missing_project_id(self, mock_hook):
mock_hook.return_value.import_instance.assert_called_once_with(
project_id=None, instance=INSTANCE_NAME, body=IMPORT_BODY
)
- self.assertTrue(result)
+ assert result
class TestCloudSqlQueryValidation(unittest.TestCase):
@@ -762,11 +762,11 @@ def test_create_operator_with_wrong_parameters(
)
)
self._setup_connections(get_connection, uri)
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = CloudSQLExecuteQueryOperator(sql=sql, task_id='task_id')
op.execute(None)
- err = cm.exception
- self.assertIn(message, str(err))
+ err = ctx.value
+ assert message in str(err)
@mock.patch("airflow.hooks.base.BaseHook.get_connection")
def test_create_operator_with_too_long_unix_socket_path(self, get_connection):
@@ -780,7 +780,7 @@ def test_create_operator_with_too_long_unix_socket_path(self, get_connection):
)
self._setup_connections(get_connection, uri)
operator = CloudSQLExecuteQueryOperator(sql=['SELECT * FROM TABLE'], task_id='task_id')
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
operator.execute(None)
- err = cm.exception
- self.assertIn("The UNIX socket path length cannot exceed", str(err))
+ err = ctx.value
+ assert "The UNIX socket path length cannot exceed" in str(err)
diff --git a/tests/providers/google/cloud/operators/test_cloud_sql_system.py b/tests/providers/google/cloud/operators/test_cloud_sql_system.py
index dbf5feed9360d..0d6e27400f3a8 100644
--- a/tests/providers/google/cloud/operators/test_cloud_sql_system.py
+++ b/tests/providers/google/cloud/operators/test_cloud_sql_system.py
@@ -114,15 +114,15 @@ def test_start_proxy_fail_no_parameters(self):
project_id=GCP_PROJECT_ID,
instance_specification='a',
)
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
runner.start_proxy()
- err = cm.exception
- self.assertIn("The cloud_sql_proxy finished early", str(err))
- with self.assertRaises(AirflowException) as cm:
+ err = ctx.value
+ assert "The cloud_sql_proxy finished early" in str(err)
+ with pytest.raises(AirflowException) as ctx:
runner.start_proxy()
- err = cm.exception
- self.assertIn("The cloud_sql_proxy finished early", str(err))
- self.assertIsNone(runner.sql_proxy_process)
+ err = ctx.value
+ assert "The cloud_sql_proxy finished early" in str(err)
+ assert runner.sql_proxy_process is None
def test_start_proxy_with_all_instances(self):
runner = CloudSqlProxyRunner(
@@ -135,7 +135,7 @@ def test_start_proxy_with_all_instances(self):
time.sleep(1)
finally:
runner.stop_proxy()
- self.assertIsNone(runner.sql_proxy_process)
+ assert runner.sql_proxy_process is None
@provide_gcp_context(GCP_CLOUDSQL_KEY)
def test_start_proxy_with_all_instances_generated_credential_file(self):
@@ -149,7 +149,7 @@ def test_start_proxy_with_all_instances_generated_credential_file(self):
time.sleep(1)
finally:
runner.stop_proxy()
- self.assertIsNone(runner.sql_proxy_process)
+ assert runner.sql_proxy_process is None
def test_start_proxy_with_all_instances_specific_version(self):
runner = CloudSqlProxyRunner(
@@ -163,8 +163,8 @@ def test_start_proxy_with_all_instances_specific_version(self):
time.sleep(1)
finally:
runner.stop_proxy()
- self.assertIsNone(runner.sql_proxy_process)
- self.assertEqual(runner.get_proxy_version(), "1.13")
+ assert runner.sql_proxy_process is None
+ assert runner.get_proxy_version() == "1.13"
@provide_gcp_context(GCP_CLOUDSQL_KEY)
def test_run_example_dag_cloudsql_query(self):
diff --git a/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py b/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py
index 9dc3060b4c432..6425f6ceb9844 100644
--- a/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py
+++ b/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py
@@ -22,6 +22,7 @@
from typing import Dict
from unittest import mock
+import pytest
from botocore.credentials import Credentials
from freezegun import freeze_time
from parameterized import parameterized
@@ -150,7 +151,7 @@ class TestTransferJobPreprocessor(unittest.TestCase):
def test_should_do_nothing_on_empty(self):
body = {}
TransferJobPreprocessor(body=body).process_body()
- self.assertEqual(body, {})
+ assert body == {}
@unittest.skipIf(boto3 is None, "Skipping test because boto3 is not available")
@mock.patch('airflow.providers.google.cloud.operators.cloud_storage_transfer_service.AwsBaseHook')
@@ -161,63 +162,59 @@ def test_should_inject_aws_credentials(self, mock_hook):
body = {TRANSFER_SPEC: deepcopy(SOURCE_AWS)}
body = TransferJobPreprocessor(body=body).process_body()
- self.assertEqual(body[TRANSFER_SPEC][AWS_S3_DATA_SOURCE][AWS_ACCESS_KEY], TEST_AWS_ACCESS_KEY)
+ assert body[TRANSFER_SPEC][AWS_S3_DATA_SOURCE][AWS_ACCESS_KEY] == TEST_AWS_ACCESS_KEY
@parameterized.expand([(SCHEDULE_START_DATE,), (SCHEDULE_END_DATE,)])
def test_should_format_date_from_python_to_dict(self, field_attr):
body = {SCHEDULE: {field_attr: NATIVE_DATE}}
TransferJobPreprocessor(body=body).process_body()
- self.assertEqual(body[SCHEDULE][field_attr], DICT_DATE)
+ assert body[SCHEDULE][field_attr] == DICT_DATE
def test_should_format_time_from_python_to_dict(self):
body = {SCHEDULE: {START_TIME_OF_DAY: NATIVE_TIME}}
TransferJobPreprocessor(body=body).process_body()
- self.assertEqual(body[SCHEDULE][START_TIME_OF_DAY], DICT_TIME)
+ assert body[SCHEDULE][START_TIME_OF_DAY] == DICT_TIME
@parameterized.expand([(SCHEDULE_START_DATE,), (SCHEDULE_END_DATE,)])
def test_should_not_change_date_for_dict(self, field_attr):
body = {SCHEDULE: {field_attr: DICT_DATE}}
TransferJobPreprocessor(body=body).process_body()
- self.assertEqual(body[SCHEDULE][field_attr], DICT_DATE)
+ assert body[SCHEDULE][field_attr] == DICT_DATE
def test_should_not_change_time_for_dict(self):
body = {SCHEDULE: {START_TIME_OF_DAY: DICT_TIME}}
TransferJobPreprocessor(body=body).process_body()
- self.assertEqual(body[SCHEDULE][START_TIME_OF_DAY], DICT_TIME)
+ assert body[SCHEDULE][START_TIME_OF_DAY] == DICT_TIME
@freeze_time("2018-10-15")
def test_should_set_default_schedule(self):
body = {}
TransferJobPreprocessor(body=body, default_schedule=True).process_body()
- self.assertEqual(
- body,
- {
- SCHEDULE: {
- SCHEDULE_END_DATE: {'day': 15, 'month': 10, 'year': 2018},
- SCHEDULE_START_DATE: {'day': 15, 'month': 10, 'year': 2018},
- }
- },
- )
+ assert body == {
+ SCHEDULE: {
+ SCHEDULE_END_DATE: {'day': 15, 'month': 10, 'year': 2018},
+ SCHEDULE_START_DATE: {'day': 15, 'month': 10, 'year': 2018},
+ }
+ }
class TestTransferJobValidator(unittest.TestCase):
def test_should_raise_exception_when_encounters_aws_credentials(self):
body = {"transferSpec": {"awsS3DataSource": {"awsAccessKey": TEST_AWS_ACCESS_KEY}}}
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
TransferJobValidator(body=body).validate_body()
- err = cm.exception
- self.assertIn(
+ err = ctx.value
+ assert (
"AWS credentials detected inside the body parameter (awsAccessKey). This is not allowed, please "
- "use Airflow connections to store credentials.",
- str(err),
+ "use Airflow connections to store credentials." in str(err)
)
def test_should_raise_exception_when_body_empty(self):
body = None
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
TransferJobValidator(body=body).validate_body()
- err = cm.exception
- self.assertIn("The required parameter 'body' is empty or None", str(err))
+ err = ctx.value
+ assert "The required parameter 'body' is empty or None" in str(err)
@parameterized.expand(
[
@@ -230,13 +227,12 @@ def test_should_raise_exception_when_body_empty(self):
def test_verify_data_source(self, transfer_spec):
body = {TRANSFER_SPEC: transfer_spec}
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
TransferJobValidator(body=body).validate_body()
- err = cm.exception
- self.assertIn(
+ err = ctx.value
+ assert (
"More than one data source detected. Please choose exactly one data source from: "
- "gcsDataSource, awsS3DataSource and httpDataSource.",
- str(err),
+ "gcsDataSource, awsS3DataSource and httpDataSource." in str(err)
)
@parameterized.expand([(VALID_TRANSFER_JOB_GCS,), (VALID_TRANSFER_JOB_AWS,)])
@@ -247,7 +243,7 @@ def test_verify_success(self, body):
except AirflowException:
validated = False
- self.assertTrue(validated)
+ assert validated
class TestGcpStorageTransferJobCreateOperator(unittest.TestCase):
@@ -273,7 +269,7 @@ def test_job_create_gcs(self, mock_hook):
mock_hook.return_value.create_transfer_job.assert_called_once_with(body=VALID_TRANSFER_JOB_GCS_RAW)
- self.assertEqual(result, VALID_TRANSFER_JOB_GCS_RAW)
+ assert result == VALID_TRANSFER_JOB_GCS_RAW
@mock.patch(
'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook'
@@ -302,7 +298,7 @@ def test_job_create_aws(self, aws_hook, mock_hook):
mock_hook.return_value.create_transfer_job.assert_called_once_with(body=VALID_TRANSFER_JOB_AWS_RAW)
- self.assertEqual(result, VALID_TRANSFER_JOB_AWS_RAW)
+ assert result == VALID_TRANSFER_JOB_AWS_RAW
@mock.patch(
'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook'
@@ -317,11 +313,11 @@ def test_job_create_multiple(self, aws_hook, gcp_hook):
op = CloudDataTransferServiceCreateJobOperator(body=body, task_id=TASK_ID)
result = op.execute(None)
- self.assertEqual(result, VALID_TRANSFER_JOB_AWS_RAW)
+ assert result == VALID_TRANSFER_JOB_AWS_RAW
op = CloudDataTransferServiceCreateJobOperator(body=body, task_id=TASK_ID)
result = op.execute(None)
- self.assertEqual(result, VALID_TRANSFER_JOB_AWS_RAW)
+ assert result == VALID_TRANSFER_JOB_AWS_RAW
# Setting all of the operator's input parameters as templated dag_ids
# (could be anything else) just to test if the templating works for all
@@ -342,9 +338,9 @@ def test_templates(self, _):
)
ti = TaskInstance(op, DEFAULT_DATE)
ti.render_templates()
- self.assertEqual(dag_id, getattr(op, 'body')[DESCRIPTION])
- self.assertEqual(dag_id, getattr(op, 'gcp_conn_id'))
- self.assertEqual(dag_id, getattr(op, 'aws_conn_id'))
+ assert dag_id == getattr(op, 'body')[DESCRIPTION]
+ assert dag_id == getattr(op, 'gcp_conn_id')
+ assert dag_id == getattr(op, 'aws_conn_id')
class TestGcpStorageTransferJobUpdateOperator(unittest.TestCase):
@@ -369,7 +365,7 @@ def test_job_update(self, mock_hook):
impersonation_chain=IMPERSONATION_CHAIN,
)
mock_hook.return_value.update_transfer_job.assert_called_once_with(job_name=JOB_NAME, body=body)
- self.assertEqual(result, VALID_TRANSFER_JOB_GCS)
+ assert result == VALID_TRANSFER_JOB_GCS
# Setting all of the operator's input parameters as templated dag_ids
# (could be anything else) just to test if the templating works for all
@@ -389,8 +385,8 @@ def test_templates(self, _):
)
ti = TaskInstance(op, DEFAULT_DATE)
ti.render_templates()
- self.assertEqual(dag_id, getattr(op, 'body')['transferJob']['name'])
- self.assertEqual(dag_id, getattr(op, 'job_name'))
+ assert dag_id == getattr(op, 'body')['transferJob']['name']
+ assert dag_id == getattr(op, 'job_name')
class TestGcpStorageTransferJobDeleteOperator(unittest.TestCase):
@@ -433,19 +429,19 @@ def test_job_delete_with_templates(self, _):
)
ti = TaskInstance(op, DEFAULT_DATE)
ti.render_templates()
- self.assertEqual(dag_id, getattr(op, 'job_name'))
- self.assertEqual(dag_id, getattr(op, 'gcp_conn_id'))
- self.assertEqual(dag_id, getattr(op, 'api_version'))
+ assert dag_id == getattr(op, 'job_name')
+ assert dag_id == getattr(op, 'gcp_conn_id')
+ assert dag_id == getattr(op, 'api_version')
@mock.patch(
'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook'
)
def test_job_delete_should_throw_ex_when_name_none(self, mock_hook):
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = CloudDataTransferServiceDeleteJobOperator(job_name="", task_id='task-id')
op.execute(None)
- err = cm.exception
- self.assertIn("The required parameter 'job_name' is empty or None", str(err))
+ err = ctx.value
+ assert "The required parameter 'job_name' is empty or None" in str(err)
mock_hook.assert_not_called()
@@ -467,7 +463,7 @@ def test_operation_get(self, mock_hook):
impersonation_chain=IMPERSONATION_CHAIN,
)
mock_hook.return_value.get_transfer_operation.assert_called_once_with(operation_name=OPERATION_NAME)
- self.assertEqual(result, VALID_OPERATION)
+ assert result == VALID_OPERATION
# Setting all of the operator's input parameters as templated dag_ids
# (could be anything else) just to test if the templating works for all
@@ -484,17 +480,17 @@ def test_operation_get_with_templates(self, _):
)
ti = TaskInstance(op, DEFAULT_DATE)
ti.render_templates()
- self.assertEqual(dag_id, getattr(op, 'operation_name'))
+ assert dag_id == getattr(op, 'operation_name')
@mock.patch(
'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook'
)
def test_operation_get_should_throw_ex_when_operation_name_none(self, mock_hook):
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = CloudDataTransferServiceGetOperationOperator(operation_name="", task_id=TASK_ID)
op.execute(None)
- err = cm.exception
- self.assertIn("The required parameter 'operation_name' is empty or None", str(err))
+ err = ctx.value
+ assert "The required parameter 'operation_name' is empty or None" in str(err)
mock_hook.assert_not_called()
@@ -516,7 +512,7 @@ def test_operation_list(self, mock_hook):
impersonation_chain=IMPERSONATION_CHAIN,
)
mock_hook.return_value.list_transfer_operations.assert_called_once_with(request_filter=TEST_FILTER)
- self.assertEqual(result, [VALID_TRANSFER_JOB_GCS])
+ assert result == [VALID_TRANSFER_JOB_GCS]
# Setting all of the operator's input parameters as templated dag_ids
# (could be anything else) just to test if the templating works for all
@@ -538,10 +534,10 @@ def test_templates(self, _):
ti.render_templates()
# pylint: disable=unsubscriptable-object
- self.assertEqual(dag_id, getattr(op, 'filter')['job_names'][0])
+ assert dag_id == getattr(op, 'filter')['job_names'][0]
# pylint: enable=unsubscriptable-object
- self.assertEqual(dag_id, getattr(op, 'gcp_conn_id'))
+ assert dag_id == getattr(op, 'gcp_conn_id')
class TestGcpStorageTransferOperationsPauseOperator(unittest.TestCase):
@@ -581,19 +577,19 @@ def test_operation_pause_with_templates(self, _):
)
ti = TaskInstance(op, DEFAULT_DATE)
ti.render_templates()
- self.assertEqual(dag_id, getattr(op, 'operation_name'))
- self.assertEqual(dag_id, getattr(op, 'gcp_conn_id'))
- self.assertEqual(dag_id, getattr(op, 'api_version'))
+ assert dag_id == getattr(op, 'operation_name')
+ assert dag_id == getattr(op, 'gcp_conn_id')
+ assert dag_id == getattr(op, 'api_version')
@mock.patch(
'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook'
)
def test_operation_pause_should_throw_ex_when_name_none(self, mock_hook):
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = CloudDataTransferServicePauseOperationOperator(operation_name="", task_id='task-id')
op.execute(None)
- err = cm.exception
- self.assertIn("The required parameter 'operation_name' is empty or None", str(err))
+ err = ctx.value
+ assert "The required parameter 'operation_name' is empty or None" in str(err)
mock_hook.assert_not_called()
@@ -616,7 +612,7 @@ def test_operation_resume(self, mock_hook):
mock_hook.return_value.resume_transfer_operation.assert_called_once_with(
operation_name=OPERATION_NAME
)
- self.assertIsNone(result)
+ assert result is None
# Setting all of the operator's input parameters as templated dag_ids
# (could be anything else) just to test if the templating works for all
@@ -637,19 +633,19 @@ def test_operation_resume_with_templates(self, _):
)
ti = TaskInstance(op, DEFAULT_DATE)
ti.render_templates()
- self.assertEqual(dag_id, getattr(op, 'operation_name'))
- self.assertEqual(dag_id, getattr(op, 'gcp_conn_id'))
- self.assertEqual(dag_id, getattr(op, 'api_version'))
+ assert dag_id == getattr(op, 'operation_name')
+ assert dag_id == getattr(op, 'gcp_conn_id')
+ assert dag_id == getattr(op, 'api_version')
@mock.patch(
'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook'
)
def test_operation_resume_should_throw_ex_when_name_none(self, mock_hook):
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = CloudDataTransferServiceResumeOperationOperator(operation_name="", task_id=TASK_ID)
op.execute(None)
- err = cm.exception
- self.assertIn("The required parameter 'operation_name' is empty or None", str(err))
+ err = ctx.value
+ assert "The required parameter 'operation_name' is empty or None" in str(err)
mock_hook.assert_not_called()
@@ -672,7 +668,7 @@ def test_operation_cancel(self, mock_hook):
mock_hook.return_value.cancel_transfer_operation.assert_called_once_with(
operation_name=OPERATION_NAME
)
- self.assertIsNone(result)
+ assert result is None
# Setting all of the operator's input parameters as templated dag_ids
# (could be anything else) just to test if the templating works for all
@@ -693,19 +689,19 @@ def test_operation_cancel_with_templates(self, _):
)
ti = TaskInstance(op, DEFAULT_DATE)
ti.render_templates()
- self.assertEqual(dag_id, getattr(op, 'operation_name'))
- self.assertEqual(dag_id, getattr(op, 'gcp_conn_id'))
- self.assertEqual(dag_id, getattr(op, 'api_version'))
+ assert dag_id == getattr(op, 'operation_name')
+ assert dag_id == getattr(op, 'gcp_conn_id')
+ assert dag_id == getattr(op, 'api_version')
@mock.patch(
'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook'
)
def test_operation_cancel_should_throw_ex_when_name_none(self, mock_hook):
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = CloudDataTransferServiceCancelOperationOperator(operation_name="", task_id=TASK_ID)
op.execute(None)
- err = cm.exception
- self.assertIn("The required parameter 'operation_name' is empty or None", str(err))
+ err = ctx.value
+ assert "The required parameter 'operation_name' is empty or None" in str(err)
mock_hook.assert_not_called()
@@ -720,12 +716,12 @@ def test_constructor(self):
schedule=SCHEDULE_DICT,
)
- self.assertEqual(operator.task_id, TASK_ID)
- self.assertEqual(operator.s3_bucket, AWS_BUCKET_NAME)
- self.assertEqual(operator.gcs_bucket, GCS_BUCKET_NAME)
- self.assertEqual(operator.project_id, GCP_PROJECT_ID)
- self.assertEqual(operator.description, DESCRIPTION)
- self.assertEqual(operator.schedule, SCHEDULE_DICT)
+ assert operator.task_id == TASK_ID
+ assert operator.s3_bucket == AWS_BUCKET_NAME
+ assert operator.gcs_bucket == GCS_BUCKET_NAME
+ assert operator.project_id == GCP_PROJECT_ID
+ assert operator.description == DESCRIPTION
+ assert operator.schedule == SCHEDULE_DICT
# Setting all of the operator's input parameters as templated dag_ids
# (could be anything else) just to test if the templating works for all
@@ -748,15 +744,15 @@ def test_templates(self, _):
)
ti = TaskInstance(op, DEFAULT_DATE)
ti.render_templates()
- self.assertEqual(dag_id, getattr(op, 's3_bucket'))
- self.assertEqual(dag_id, getattr(op, 'gcs_bucket'))
- self.assertEqual(dag_id, getattr(op, 'description'))
+ assert dag_id == getattr(op, 's3_bucket')
+ assert dag_id == getattr(op, 'gcs_bucket')
+ assert dag_id == getattr(op, 'description')
# pylint: disable=unsubscriptable-object
- self.assertEqual(dag_id, getattr(op, 'object_conditions')['exclude_prefixes'][0])
+ assert dag_id == getattr(op, 'object_conditions')['exclude_prefixes'][0]
# pylint: enable=unsubscriptable-object
- self.assertEqual(dag_id, getattr(op, 'gcp_conn_id'))
+ assert dag_id == getattr(op, 'gcp_conn_id')
@mock.patch(
'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook'
@@ -781,7 +777,7 @@ def test_execute(self, mock_aws_hook, mock_transfer_hook):
body=VALID_TRANSFER_JOB_AWS_RAW
)
- self.assertTrue(mock_transfer_hook.return_value.wait_for_transfer_job.called)
+ assert mock_transfer_hook.return_value.wait_for_transfer_job.called
@mock.patch(
'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook'
@@ -807,7 +803,7 @@ def test_execute_skip_wait(self, mock_aws_hook, mock_transfer_hook):
body=VALID_TRANSFER_JOB_AWS_RAW
)
- self.assertFalse(mock_transfer_hook.return_value.wait_for_transfer_job.called)
+ assert not mock_transfer_hook.return_value.wait_for_transfer_job.called
class TestGoogleCloudStorageToGoogleCloudStorageTransferOperator(unittest.TestCase):
@@ -821,12 +817,12 @@ def test_constructor(self):
schedule=SCHEDULE_DICT,
)
- self.assertEqual(operator.task_id, TASK_ID)
- self.assertEqual(operator.source_bucket, GCS_BUCKET_NAME)
- self.assertEqual(operator.destination_bucket, GCS_BUCKET_NAME)
- self.assertEqual(operator.project_id, GCP_PROJECT_ID)
- self.assertEqual(operator.description, DESCRIPTION)
- self.assertEqual(operator.schedule, SCHEDULE_DICT)
+ assert operator.task_id == TASK_ID
+ assert operator.source_bucket == GCS_BUCKET_NAME
+ assert operator.destination_bucket == GCS_BUCKET_NAME
+ assert operator.project_id == GCP_PROJECT_ID
+ assert operator.description == DESCRIPTION
+ assert operator.schedule == SCHEDULE_DICT
# Setting all of the operator's input parameters as templated dag_ids
# (could be anything else) just to test if the templating works for all
@@ -849,15 +845,15 @@ def test_templates(self, _):
)
ti = TaskInstance(op, DEFAULT_DATE)
ti.render_templates()
- self.assertEqual(dag_id, getattr(op, 'source_bucket'))
- self.assertEqual(dag_id, getattr(op, 'destination_bucket'))
- self.assertEqual(dag_id, getattr(op, 'description'))
+ assert dag_id == getattr(op, 'source_bucket')
+ assert dag_id == getattr(op, 'destination_bucket')
+ assert dag_id == getattr(op, 'description')
# pylint: disable=unsubscriptable-object
- self.assertEqual(dag_id, getattr(op, 'object_conditions')['exclude_prefixes'][0])
+ assert dag_id == getattr(op, 'object_conditions')['exclude_prefixes'][0]
# pylint: enable=unsubscriptable-object
- self.assertEqual(dag_id, getattr(op, 'gcp_conn_id'))
+ assert dag_id == getattr(op, 'gcp_conn_id')
@mock.patch(
'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook'
@@ -876,7 +872,7 @@ def test_execute(self, mock_transfer_hook):
mock_transfer_hook.return_value.create_transfer_job.assert_called_once_with(
body=VALID_TRANSFER_JOB_GCS_RAW
)
- self.assertTrue(mock_transfer_hook.return_value.wait_for_transfer_job.called)
+ assert mock_transfer_hook.return_value.wait_for_transfer_job.called
@mock.patch(
'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook'
@@ -896,4 +892,4 @@ def test_execute_skip_wait(self, mock_transfer_hook):
mock_transfer_hook.return_value.create_transfer_job.assert_called_once_with(
body=VALID_TRANSFER_JOB_GCS_RAW
)
- self.assertFalse(mock_transfer_hook.return_value.wait_for_transfer_job.called)
+ assert not mock_transfer_hook.return_value.wait_for_transfer_job.called
diff --git a/tests/providers/google/cloud/operators/test_compute.py b/tests/providers/google/cloud/operators/test_compute.py
index ccd0b0190b4d0..5359cd80376de 100644
--- a/tests/providers/google/cloud/operators/test_compute.py
+++ b/tests/providers/google/cloud/operators/test_compute.py
@@ -24,6 +24,7 @@
from unittest import mock
import httplib2
+import pytest
from googleapiclient.errors import HttpError
from airflow.exceptions import AirflowException
@@ -64,7 +65,7 @@ def test_instance_start(self, mock_hook):
mock_hook.return_value.start_instance.assert_called_once_with(
zone=GCE_ZONE, resource_id=RESOURCE_ID, project_id=GCP_PROJECT_ID
)
- self.assertTrue(result)
+ assert result
# Setting all of the operator's input parameters as template dag_ids
# (could be anything else) just to test if the templating works for all fields
@@ -84,21 +85,21 @@ def test_instance_start_with_templates(self, _):
)
ti = TaskInstance(op, DEFAULT_DATE)
ti.render_templates()
- self.assertEqual(dag_id, getattr(op, 'project_id'))
- self.assertEqual(dag_id, getattr(op, 'zone'))
- self.assertEqual(dag_id, getattr(op, 'resource_id'))
- self.assertEqual(dag_id, getattr(op, 'gcp_conn_id'))
- self.assertEqual(dag_id, getattr(op, 'api_version'))
+ assert dag_id == getattr(op, 'project_id')
+ assert dag_id == getattr(op, 'zone')
+ assert dag_id == getattr(op, 'resource_id')
+ assert dag_id == getattr(op, 'gcp_conn_id')
+ assert dag_id == getattr(op, 'api_version')
@mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook')
def test_start_should_throw_ex_when_missing_project_id(self, mock_hook):
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = ComputeEngineStartInstanceOperator(
project_id="", zone=GCE_ZONE, resource_id=RESOURCE_ID, task_id='id'
)
op.execute(None)
- err = cm.exception
- self.assertIn("The required parameter 'project_id' is missing", str(err))
+ err = ctx.value
+ assert "The required parameter 'project_id' is missing" in str(err)
mock_hook.assert_not_called()
@mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook')
@@ -108,24 +109,24 @@ def test_start_should_not_throw_ex_when_project_id_none(self, _):
@mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook')
def test_start_should_throw_ex_when_missing_zone(self, mock_hook):
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = ComputeEngineStartInstanceOperator(
project_id=GCP_PROJECT_ID, zone="", resource_id=RESOURCE_ID, task_id='id'
)
op.execute(None)
- err = cm.exception
- self.assertIn("The required parameter 'zone' is missing", str(err))
+ err = ctx.value
+ assert "The required parameter 'zone' is missing" in str(err)
mock_hook.assert_not_called()
@mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook')
def test_start_should_throw_ex_when_missing_resource_id(self, mock_hook):
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = ComputeEngineStartInstanceOperator(
project_id=GCP_PROJECT_ID, zone=GCE_ZONE, resource_id="", task_id='id'
)
op.execute(None)
- err = cm.exception
- self.assertIn("The required parameter 'resource_id' is missing", str(err))
+ err = ctx.value
+ assert "The required parameter 'resource_id' is missing" in str(err)
mock_hook.assert_not_called()
@@ -163,21 +164,21 @@ def test_instance_stop_with_templates(self, _):
)
ti = TaskInstance(op, DEFAULT_DATE)
ti.render_templates()
- self.assertEqual(dag_id, getattr(op, 'project_id'))
- self.assertEqual(dag_id, getattr(op, 'zone'))
- self.assertEqual(dag_id, getattr(op, 'resource_id'))
- self.assertEqual(dag_id, getattr(op, 'gcp_conn_id'))
- self.assertEqual(dag_id, getattr(op, 'api_version'))
+ assert dag_id == getattr(op, 'project_id')
+ assert dag_id == getattr(op, 'zone')
+ assert dag_id == getattr(op, 'resource_id')
+ assert dag_id == getattr(op, 'gcp_conn_id')
+ assert dag_id == getattr(op, 'api_version')
@mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook')
def test_stop_should_throw_ex_when_missing_project_id(self, mock_hook):
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = ComputeEngineStopInstanceOperator(
project_id="", zone=GCE_ZONE, resource_id=RESOURCE_ID, task_id='id'
)
op.execute(None)
- err = cm.exception
- self.assertIn("The required parameter 'project_id' is missing", str(err))
+ err = ctx.value
+ assert "The required parameter 'project_id' is missing" in str(err)
mock_hook.assert_not_called()
@mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook')
@@ -195,24 +196,24 @@ def test_stop_should_not_throw_ex_when_project_id_none(self, mock_hook):
@mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook')
def test_stop_should_throw_ex_when_missing_zone(self, mock_hook):
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = ComputeEngineStopInstanceOperator(
project_id=GCP_PROJECT_ID, zone="", resource_id=RESOURCE_ID, task_id='id'
)
op.execute(None)
- err = cm.exception
- self.assertIn("The required parameter 'zone' is missing", str(err))
+ err = ctx.value
+ assert "The required parameter 'zone' is missing" in str(err)
mock_hook.assert_not_called()
@mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook')
def test_stop_should_throw_ex_when_missing_resource_id(self, mock_hook):
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = ComputeEngineStopInstanceOperator(
project_id=GCP_PROJECT_ID, zone=GCE_ZONE, resource_id="", task_id='id'
)
op.execute(None)
- err = cm.exception
- self.assertIn("The required parameter 'resource_id' is missing", str(err))
+ err = ctx.value
+ assert "The required parameter 'resource_id' is missing" in str(err)
mock_hook.assert_not_called()
@@ -256,15 +257,15 @@ def test_set_machine_type_with_templates(self, _):
)
ti = TaskInstance(op, DEFAULT_DATE)
ti.render_templates()
- self.assertEqual(dag_id, getattr(op, 'project_id'))
- self.assertEqual(dag_id, getattr(op, 'zone'))
- self.assertEqual(dag_id, getattr(op, 'resource_id'))
- self.assertEqual(dag_id, getattr(op, 'gcp_conn_id'))
- self.assertEqual(dag_id, getattr(op, 'api_version'))
+ assert dag_id == getattr(op, 'project_id')
+ assert dag_id == getattr(op, 'zone')
+ assert dag_id == getattr(op, 'resource_id')
+ assert dag_id == getattr(op, 'gcp_conn_id')
+ assert dag_id == getattr(op, 'api_version')
@mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook')
def test_set_machine_type_should_throw_ex_when_missing_project_id(self, mock_hook):
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = ComputeEngineSetMachineTypeOperator(
project_id="",
zone=GCE_ZONE,
@@ -273,8 +274,8 @@ def test_set_machine_type_should_throw_ex_when_missing_project_id(self, mock_hoo
task_id='id',
)
op.execute(None)
- err = cm.exception
- self.assertIn("The required parameter 'project_id' is missing", str(err))
+ err = ctx.value
+ assert "The required parameter 'project_id' is missing" in str(err)
mock_hook.assert_not_called()
@mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook')
@@ -294,7 +295,7 @@ def test_set_machine_type_should_not_throw_ex_when_project_id_none(self, mock_ho
@mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook')
def test_set_machine_type_should_throw_ex_when_missing_zone(self, mock_hook):
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = ComputeEngineSetMachineTypeOperator(
project_id=GCP_PROJECT_ID,
zone="",
@@ -303,13 +304,13 @@ def test_set_machine_type_should_throw_ex_when_missing_zone(self, mock_hook):
task_id='id',
)
op.execute(None)
- err = cm.exception
- self.assertIn("The required parameter 'zone' is missing", str(err))
+ err = ctx.value
+ assert "The required parameter 'zone' is missing" in str(err)
mock_hook.assert_not_called()
@mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook')
def test_set_machine_type_should_throw_ex_when_missing_resource_id(self, mock_hook):
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = ComputeEngineSetMachineTypeOperator(
project_id=GCP_PROJECT_ID,
zone=GCE_ZONE,
@@ -318,19 +319,19 @@ def test_set_machine_type_should_throw_ex_when_missing_resource_id(self, mock_ho
task_id='id',
)
op.execute(None)
- err = cm.exception
- self.assertIn("The required parameter 'resource_id' is missing", str(err))
+ err = ctx.value
+ assert "The required parameter 'resource_id' is missing" in str(err)
mock_hook.assert_not_called()
@mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook')
def test_set_machine_type_should_throw_ex_when_missing_machine_type(self, mock_hook):
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = ComputeEngineSetMachineTypeOperator(
project_id=GCP_PROJECT_ID, zone=GCE_ZONE, resource_id=RESOURCE_ID, body={}, task_id='id'
)
op.execute(None)
- err = cm.exception
- self.assertIn("The required body field 'machineType' is missing. Please add it.", str(err))
+ err = ctx.value
+ assert "The required body field 'machineType' is missing. Please add it." in str(err)
mock_hook.assert_called_once_with(
api_version='v1',
gcp_conn_id='google_cloud_default',
@@ -373,7 +374,7 @@ def test_set_machine_type_should_handle_and_trim_gce_error(
get_conn.return_value = {}
_execute_set_machine_type.return_value = {"name": "test-operation"}
_check_zone_operation_status.return_value = ast.literal_eval(self.MOCK_OP_RESPONSE)
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = ComputeEngineSetMachineTypeOperator(
project_id=GCP_PROJECT_ID,
zone=GCE_ZONE,
@@ -382,7 +383,7 @@ def test_set_machine_type_should_handle_and_trim_gce_error(
task_id='id',
)
op.execute(None)
- err = cm.exception
+ err = ctx.value
_check_zone_operation_status.assert_called_once_with(
{}, "test-operation", GCP_PROJECT_ID, GCE_ZONE, mock.ANY
)
@@ -391,8 +392,8 @@ def test_set_machine_type_should_handle_and_trim_gce_error(
)
# Checking the full message was sometimes failing due to different order
# of keys in the serialized JSON
- self.assertIn("400 BAD REQUEST: {", str(err)) # checking the square bracket trim
- self.assertIn("UNSUPPORTED_OPERATION", str(err))
+ assert "400 BAD REQUEST: {" in str(err) # checking the square bracket trim
+ assert "UNSUPPORTED_OPERATION" in str(err)
GCE_INSTANCE_TEMPLATE_NAME = "instance-template-test"
@@ -497,7 +498,7 @@ def test_successful_copy_template(self, mock_hook):
mock_hook.return_value.insert_instance_template.assert_called_once_with(
project_id=GCP_PROJECT_ID, body=GCE_INSTANCE_TEMPLATE_BODY_INSERT, request_id=None
)
- self.assertEqual(GCE_INSTANCE_TEMPLATE_BODY_GET_NEW, result)
+ assert GCE_INSTANCE_TEMPLATE_BODY_GET_NEW == result
@mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook')
def test_successful_copy_template_missing_project_id(self, mock_hook):
@@ -520,7 +521,7 @@ def test_successful_copy_template_missing_project_id(self, mock_hook):
mock_hook.return_value.insert_instance_template.assert_called_once_with(
project_id=None, body=GCE_INSTANCE_TEMPLATE_BODY_INSERT, request_id=None
)
- self.assertEqual(GCE_INSTANCE_TEMPLATE_BODY_GET_NEW, result)
+ assert GCE_INSTANCE_TEMPLATE_BODY_GET_NEW == result
@mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook')
def test_idempotent_copy_template_when_already_copied(self, mock_hook):
@@ -538,7 +539,7 @@ def test_idempotent_copy_template_when_already_copied(self, mock_hook):
impersonation_chain=None,
)
mock_hook.return_value.insert_instance_template.assert_not_called()
- self.assertEqual(GCE_INSTANCE_TEMPLATE_BODY_GET_NEW, result)
+ assert GCE_INSTANCE_TEMPLATE_BODY_GET_NEW == result
@mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook')
def test_successful_copy_template_with_request_id(self, mock_hook):
@@ -565,7 +566,7 @@ def test_successful_copy_template_with_request_id(self, mock_hook):
body=GCE_INSTANCE_TEMPLATE_BODY_INSERT,
request_id=GCE_INSTANCE_TEMPLATE_REQUEST_ID,
)
- self.assertEqual(GCE_INSTANCE_TEMPLATE_BODY_GET_NEW, result)
+ assert GCE_INSTANCE_TEMPLATE_BODY_GET_NEW == result
@mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook')
def test_successful_copy_template_with_description_fields(self, mock_hook):
@@ -595,7 +596,7 @@ def test_successful_copy_template_with_description_fields(self, mock_hook):
body=body_insert,
request_id=GCE_INSTANCE_TEMPLATE_REQUEST_ID,
)
- self.assertEqual(GCE_INSTANCE_TEMPLATE_BODY_GET_NEW, result)
+ assert GCE_INSTANCE_TEMPLATE_BODY_GET_NEW == result
@mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook')
def test_copy_with_some_validation_warnings(self, mock_hook):
@@ -628,7 +629,7 @@ def test_copy_with_some_validation_warnings(self, mock_hook):
body=body_insert,
request_id=None,
)
- self.assertEqual(GCE_INSTANCE_TEMPLATE_BODY_GET_NEW, result)
+ assert GCE_INSTANCE_TEMPLATE_BODY_GET_NEW == result
@mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook')
def test_successful_copy_template_with_updated_nested_fields(self, mock_hook):
@@ -659,7 +660,7 @@ def test_successful_copy_template_with_updated_nested_fields(self, mock_hook):
mock_hook.return_value.insert_instance_template.assert_called_once_with(
project_id=GCP_PROJECT_ID, body=body_insert, request_id=None
)
- self.assertEqual(GCE_INSTANCE_TEMPLATE_BODY_GET_NEW, result)
+ assert GCE_INSTANCE_TEMPLATE_BODY_GET_NEW == result
@mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook')
def test_successful_copy_template_with_smaller_array_fields(self, mock_hook):
@@ -703,7 +704,7 @@ def test_successful_copy_template_with_smaller_array_fields(self, mock_hook):
mock_hook.return_value.insert_instance_template.assert_called_once_with(
project_id=GCP_PROJECT_ID, body=body_insert, request_id=None
)
- self.assertEqual(GCE_INSTANCE_TEMPLATE_BODY_GET_NEW, result)
+ assert GCE_INSTANCE_TEMPLATE_BODY_GET_NEW == result
@mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook')
def test_successful_copy_template_with_bigger_array_fields(self, mock_hook):
@@ -767,7 +768,7 @@ def test_successful_copy_template_with_bigger_array_fields(self, mock_hook):
body=body_insert,
request_id=None,
)
- self.assertEqual(GCE_INSTANCE_TEMPLATE_BODY_GET_NEW, result)
+ assert GCE_INSTANCE_TEMPLATE_BODY_GET_NEW == result
@mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook')
def test_missing_name(self, mock_hook):
@@ -776,7 +777,7 @@ def test_missing_name(self, mock_hook):
GCE_INSTANCE_TEMPLATE_BODY_GET,
GCE_INSTANCE_TEMPLATE_BODY_GET_NEW,
]
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = ComputeEngineCopyInstanceTemplateOperator(
project_id=GCP_PROJECT_ID,
resource_id=GCE_INSTANCE_TEMPLATE_NAME,
@@ -785,8 +786,8 @@ def test_missing_name(self, mock_hook):
body_patch={"description": "New description"},
)
op.execute(None)
- err = cm.exception
- self.assertIn("should contain at least name for the new operator in the 'name' field", str(err))
+ err = ctx.value
+ assert "should contain at least name for the new operator in the 'name' field" in str(err)
mock_hook.assert_not_called()
@@ -900,7 +901,7 @@ def test_successful_instance_group_update(self, mock_hook):
body=GCE_INSTANCE_GROUP_MANAGER_EXPECTED_PATCH,
request_id=None,
)
- self.assertTrue(result)
+ assert result
@mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook')
def test_successful_instance_group_update_missing_project_id(self, mock_hook):
@@ -927,7 +928,7 @@ def test_successful_instance_group_update_missing_project_id(self, mock_hook):
body=GCE_INSTANCE_GROUP_MANAGER_EXPECTED_PATCH,
request_id=None,
)
- self.assertTrue(result)
+ assert result
@mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook')
def test_successful_instance_group_update_no_instance_template_field(self, mock_hook):
@@ -957,7 +958,7 @@ def test_successful_instance_group_update_no_instance_template_field(self, mock_
body=expected_patch_no_instance_template,
request_id=None,
)
- self.assertTrue(result)
+ assert result
@mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook')
def test_successful_instance_group_update_no_versions_field(self, mock_hook):
@@ -987,7 +988,7 @@ def test_successful_instance_group_update_no_versions_field(self, mock_hook):
body=expected_patch_no_versions,
request_id=None,
)
- self.assertTrue(result)
+ assert result
@mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook')
def test_successful_instance_group_update_with_update_policy(self, mock_hook):
@@ -1018,7 +1019,7 @@ def test_successful_instance_group_update_with_update_policy(self, mock_hook):
body=expected_patch_with_update_policy,
request_id=None,
)
- self.assertTrue(result)
+ assert result
@mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook')
def test_successful_instance_group_update_with_request_id(self, mock_hook):
@@ -1047,11 +1048,11 @@ def test_successful_instance_group_update_with_request_id(self, mock_hook):
body=GCE_INSTANCE_GROUP_MANAGER_EXPECTED_PATCH,
request_id=GCE_INSTANCE_GROUP_MANAGER_REQUEST_ID,
)
- self.assertTrue(result)
+ assert result
@mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook')
def test_try_to_use_api_v1(self, _):
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
ComputeEngineInstanceGroupUpdateManagerTemplateOperator(
project_id=GCP_PROJECT_ID,
zone=GCE_ZONE,
@@ -1061,8 +1062,8 @@ def test_try_to_use_api_v1(self, _):
source_template=GCE_INSTANCE_TEMPLATE_SOURCE_URL,
destination_template=GCE_INSTANCE_TEMPLATE_DESTINATION_URL,
)
- err = cm.exception
- self.assertIn("Use beta api version or above", str(err))
+ err = ctx.value
+ assert "Use beta api version or above" in str(err)
@mock.patch('airflow.providers.google.cloud.operators.compute.ComputeEngineHook')
def test_try_to_use_non_existing_template(self, mock_hook):
@@ -1084,4 +1085,4 @@ def test_try_to_use_non_existing_template(self, mock_hook):
impersonation_chain=None,
)
mock_hook.return_value.patch_instance_group_manager.assert_not_called()
- self.assertTrue(result)
+ assert result
diff --git a/tests/providers/google/cloud/operators/test_datacatalog.py b/tests/providers/google/cloud/operators/test_datacatalog.py
index 7d17fac0cb026..517b35c71edc1 100644
--- a/tests/providers/google/cloud/operators/test_datacatalog.py
+++ b/tests/providers/google/cloud/operators/test_datacatalog.py
@@ -144,7 +144,7 @@ def test_assert_valid_hook_call(self, mock_hook) -> None:
metadata=TEST_METADATA,
)
ti.xcom_push.assert_called_once_with(key="entry_id", value=TEST_ENTRY_ID)
- self.assertEqual(TEST_ENTRY_DICT, result)
+ assert TEST_ENTRY_DICT == result
@mock.patch(
"airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook",
@@ -193,7 +193,7 @@ def test_assert_valid_hook_call_when_exists(self, mock_hook) -> None:
metadata=TEST_METADATA,
)
ti.xcom_push.assert_called_once_with(key="entry_id", value=TEST_ENTRY_ID)
- self.assertEqual(TEST_ENTRY_DICT, result)
+ assert TEST_ENTRY_DICT == result
class TestCloudDataCatalogCreateEntryGroupOperator(TestCase):
@@ -230,7 +230,7 @@ def test_assert_valid_hook_call(self, mock_hook) -> None:
metadata=TEST_METADATA,
)
ti.xcom_push.assert_called_once_with(key="entry_group_id", value=TEST_ENTRY_GROUP_ID)
- self.assertEqual(result, TEST_ENTRY_GROUP_DICT)
+ assert result == TEST_ENTRY_GROUP_DICT
class TestCloudDataCatalogCreateTagOperator(TestCase):
@@ -271,7 +271,7 @@ def test_assert_valid_hook_call(self, mock_hook) -> None:
metadata=TEST_METADATA,
)
ti.xcom_push.assert_called_once_with(key="tag_id", value=TEST_TAG_ID)
- self.assertEqual(TEST_TAG_DICT, result)
+ assert TEST_TAG_DICT == result
class TestCloudDataCatalogCreateTagTemplateOperator(TestCase):
@@ -308,7 +308,7 @@ def test_assert_valid_hook_call(self, mock_hook) -> None:
metadata=TEST_METADATA,
)
ti.xcom_push.assert_called_once_with(key="tag_template_id", value=TEST_TAG_TEMPLATE_ID)
- self.assertEqual(TEST_TAG_TEMPLATE_DICT, result)
+ assert TEST_TAG_TEMPLATE_DICT == result
class TestCloudDataCatalogCreateTagTemplateFieldOperator(TestCase):
@@ -347,7 +347,7 @@ def test_assert_valid_hook_call(self, mock_hook) -> None:
metadata=TEST_METADATA,
)
ti.xcom_push.assert_called_once_with(key="tag_template_field_id", value=TEST_TAG_TEMPLATE_FIELD_ID)
- self.assertEqual(TEST_TAG_TEMPLATE_FIELD_DICT, result)
+ assert TEST_TAG_TEMPLATE_FIELD_DICT == result
class TestCloudDataCatalogDeleteEntryOperator(TestCase):
diff --git a/tests/providers/google/cloud/operators/test_dataflow.py b/tests/providers/google/cloud/operators/test_dataflow.py
index 9cb7490990e0c..7e290d7f05ca7 100644
--- a/tests/providers/google/cloud/operators/test_dataflow.py
+++ b/tests/providers/google/cloud/operators/test_dataflow.py
@@ -106,14 +106,14 @@ def setUp(self):
def test_init(self):
"""Test DataFlowPythonOperator instance is properly initialized."""
- self.assertEqual(self.dataflow.task_id, TASK_ID)
- self.assertEqual(self.dataflow.job_name, JOB_NAME)
- self.assertEqual(self.dataflow.py_file, PY_FILE)
- self.assertEqual(self.dataflow.py_options, PY_OPTIONS)
- self.assertEqual(self.dataflow.py_interpreter, PY_INTERPRETER)
- self.assertEqual(self.dataflow.poll_sleep, POLL_SLEEP)
- self.assertEqual(self.dataflow.dataflow_default_options, DEFAULT_OPTIONS_PYTHON)
- self.assertEqual(self.dataflow.options, EXPECTED_ADDITIONAL_OPTIONS)
+ assert self.dataflow.task_id == TASK_ID
+ assert self.dataflow.job_name == JOB_NAME
+ assert self.dataflow.py_file == PY_FILE
+ assert self.dataflow.py_options == PY_OPTIONS
+ assert self.dataflow.py_interpreter == PY_INTERPRETER
+ assert self.dataflow.poll_sleep == POLL_SLEEP
+ assert self.dataflow.dataflow_default_options == DEFAULT_OPTIONS_PYTHON
+ assert self.dataflow.options == EXPECTED_ADDITIONAL_OPTIONS
@mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
@mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
@@ -125,7 +125,7 @@ def test_exec(self, gcs_hook, dataflow_mock):
start_python_hook = dataflow_mock.return_value.start_python_dataflow
gcs_provide_file = gcs_hook.return_value.provide_file
self.dataflow.execute(None)
- self.assertTrue(dataflow_mock.called)
+ assert dataflow_mock.called
expected_options = {
'project': 'test',
'staging_location': 'gs://test/staging',
@@ -145,7 +145,7 @@ def test_exec(self, gcs_hook, dataflow_mock):
project_id=None,
location=TEST_LOCATION,
)
- self.assertTrue(self.dataflow.py_file.startswith('/tmp/dataflow'))
+ assert self.dataflow.py_file.startswith('/tmp/dataflow')
class TestDataflowJavaOperator(unittest.TestCase):
@@ -163,14 +163,14 @@ def setUp(self):
def test_init(self):
"""Test DataflowTemplateOperator instance is properly initialized."""
- self.assertEqual(self.dataflow.task_id, TASK_ID)
- self.assertEqual(self.dataflow.job_name, JOB_NAME)
- self.assertEqual(self.dataflow.poll_sleep, POLL_SLEEP)
- self.assertEqual(self.dataflow.dataflow_default_options, DEFAULT_OPTIONS_JAVA)
- self.assertEqual(self.dataflow.job_class, JOB_CLASS)
- self.assertEqual(self.dataflow.jar, JAR_FILE)
- self.assertEqual(self.dataflow.options, EXPECTED_ADDITIONAL_OPTIONS)
- self.assertEqual(self.dataflow.check_if_running, CheckJobRunning.WaitForRun)
+ assert self.dataflow.task_id == TASK_ID
+ assert self.dataflow.job_name == JOB_NAME
+ assert self.dataflow.poll_sleep == POLL_SLEEP
+ assert self.dataflow.dataflow_default_options == DEFAULT_OPTIONS_JAVA
+ assert self.dataflow.job_class == JOB_CLASS
+ assert self.dataflow.jar == JAR_FILE
+ assert self.dataflow.options == EXPECTED_ADDITIONAL_OPTIONS
+ assert self.dataflow.check_if_running == CheckJobRunning.WaitForRun
@mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
@mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
@@ -183,7 +183,7 @@ def test_exec(self, gcs_hook, dataflow_mock):
gcs_provide_file = gcs_hook.return_value.provide_file
self.dataflow.check_if_running = CheckJobRunning.IgnoreJob
self.dataflow.execute(None)
- self.assertTrue(dataflow_mock.called)
+ assert dataflow_mock.called
gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
start_java_hook.assert_called_once_with(
job_name=JOB_NAME,
@@ -210,7 +210,7 @@ def test_check_job_running_exec(self, gcs_hook, dataflow_mock):
gcs_provide_file = gcs_hook.return_value.provide_file
self.dataflow.check_if_running = True
self.dataflow.execute(None)
- self.assertTrue(dataflow_mock.called)
+ assert dataflow_mock.called
gcs_provide_file.assert_not_called()
start_java_hook.assert_not_called()
dataflow_running.assert_called_once_with(
@@ -230,7 +230,7 @@ def test_check_job_not_running_exec(self, gcs_hook, dataflow_mock):
gcs_provide_file = gcs_hook.return_value.provide_file
self.dataflow.check_if_running = True
self.dataflow.execute(None)
- self.assertTrue(dataflow_mock.called)
+ assert dataflow_mock.called
gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
start_java_hook.assert_called_once_with(
job_name=JOB_NAME,
@@ -261,7 +261,7 @@ def test_check_multiple_job_exec(self, gcs_hook, dataflow_mock):
self.dataflow.multiple_jobs = True
self.dataflow.check_if_running = True
self.dataflow.execute(None)
- self.assertTrue(dataflow_mock.called)
+ assert dataflow_mock.called
gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
start_java_hook.assert_called_once_with(
job_name=JOB_NAME,
@@ -301,7 +301,7 @@ def test_exec(self, dataflow_mock):
"""
start_template_hook = dataflow_mock.return_value.start_template_dataflow
self.dataflow.execute(None)
- self.assertTrue(dataflow_mock.called)
+ assert dataflow_mock.called
expected_options = {
'project': 'test',
'stagingLocation': 'gs://test/staging',
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py
index a9752e3251d23..8c06ef7c59312 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -18,9 +18,9 @@
import inspect
import unittest
from datetime import datetime
-from typing import Any
from unittest import mock
+import pytest
from google.api_core.exceptions import AlreadyExists, NotFound
from google.api_core.retry import Retry
@@ -129,27 +129,27 @@
}
-def assert_warning(msg: str, warning: Any):
- assert any(msg in str(w) for w in warning.warnings)
+def assert_warning(msg: str, warnings):
+ assert any(msg in str(w) for w in warnings)
class TestsClusterGenerator(unittest.TestCase):
def test_image_version(self):
- with self.assertRaises(ValueError) as err:
+ with pytest.raises(ValueError) as ctx:
ClusterGenerator(
custom_image="custom_image",
image_version="image_version",
project_id=GCP_PROJECT,
cluster_name=CLUSTER_NAME,
)
- self.assertIn("custom_image and image_version", str(err))
+ assert "custom_image and image_version" in str(ctx.value)
def test_nodes_number(self):
- with self.assertRaises(AssertionError) as err:
+ with pytest.raises(AssertionError) as ctx:
ClusterGenerator(
num_workers=0, num_preemptible_workers=0, project_id=GCP_PROJECT, cluster_name=CLUSTER_NAME
)
- self.assertIn("num_workers == 0 means single", str(err))
+ assert "num_workers == 0 means single" in str(ctx.value)
def test_build(self):
generator = ClusterGenerator(
@@ -186,12 +186,12 @@ def test_build(self):
customer_managed_key="customer_managed_key",
)
cluster = generator.make()
- self.assertDictEqual(CONFIG, cluster)
+ assert CONFIG == cluster
class TestDataprocClusterCreateOperator(unittest.TestCase):
def test_deprecation_warning(self):
- with self.assertWarns(DeprecationWarning) as warning:
+ with pytest.warns(DeprecationWarning) as warnings:
op = DataprocCreateClusterOperator(
task_id=TASK_ID,
region=GCP_LOCATION,
@@ -200,22 +200,22 @@ def test_deprecation_warning(self):
num_workers=2,
zone="zone",
)
- assert_warning("Passing cluster parameters by keywords", warning)
+ assert_warning("Passing cluster parameters by keywords", warnings)
- self.assertEqual(op.project_id, GCP_PROJECT)
- self.assertEqual(op.cluster_name, "cluster_name")
- self.assertEqual(op.cluster_config['worker_config']['num_instances'], 2)
- self.assertIn("zones/zone", op.cluster_config['master_config']["machine_type_uri"])
+ assert op.project_id == GCP_PROJECT
+ assert op.cluster_name == "cluster_name"
+ assert op.cluster_config['worker_config']['num_instances'] == 2
+ assert "zones/zone" in op.cluster_config['master_config']["machine_type_uri"]
- with self.assertWarns(DeprecationWarning) as warning:
+ with pytest.warns(DeprecationWarning) as warnings:
op_default_region = DataprocCreateClusterOperator(
task_id=TASK_ID,
project_id=GCP_PROJECT,
cluster_name="cluster_name",
cluster_config=op.cluster_config,
)
- assert_warning("Default region value", warning)
- self.assertEqual(op_default_region.region, 'global')
+ assert_warning("Default region value", warnings)
+ assert op_default_region.region == 'global'
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute(self, mock_hook):
@@ -305,7 +305,7 @@ def test_execute_if_cluster_exists_do_not_use(self, mock_hook):
request_id=REQUEST_ID,
use_if_exists=False,
)
- with self.assertRaises(AlreadyExists):
+ with pytest.raises(AlreadyExists):
op.execute(context={})
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -329,7 +329,7 @@ def test_execute_if_cluster_exists_in_error_state(self, mock_hook):
metadata=METADATA,
request_id=REQUEST_ID,
)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
op.execute(context={})
mock_hook.return_value.diagnose_cluster.assert_called_once_with(
@@ -368,7 +368,7 @@ def test_execute_if_cluster_exists_in_deleting_state(
delete_on_error=True,
gcp_conn_id=GCP_CONN_ID,
)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
op.execute(context={})
calls = [mock.call(mock_hook.return_value), mock.call(mock_hook.return_value)]
@@ -381,9 +381,9 @@ def test_execute_if_cluster_exists_in_deleting_state(
class TestDataprocClusterScaleOperator(unittest.TestCase):
def test_deprecation_warning(self):
- with self.assertWarns(DeprecationWarning) as warning:
+ with pytest.warns(DeprecationWarning) as warnings:
DataprocScaleClusterOperator(task_id=TASK_ID, cluster_name=CLUSTER_NAME, project_id=GCP_PROJECT)
- assert_warning("DataprocUpdateClusterOperator", warning)
+ assert_warning("DataprocUpdateClusterOperator", warnings)
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute(self, mock_hook):
@@ -662,9 +662,9 @@ class TestDataProcHiveOperator(unittest.TestCase):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_deprecation_warning(self, mock_hook):
- with self.assertWarns(DeprecationWarning) as warning:
+ with pytest.warns(DeprecationWarning) as warnings:
DataprocSubmitHiveJobOperator(task_id=TASK_ID, region=GCP_LOCATION, query="query")
- assert_warning("DataprocSubmitJobOperator", warning)
+ assert_warning("DataprocSubmitJobOperator", warnings)
@mock.patch(DATAPROC_PATH.format("uuid.uuid4"))
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -705,7 +705,7 @@ def test_builder(self, mock_hook, mock_uuid):
variables=self.variables,
)
job = op.generate_job()
- self.assertDictEqual(self.job, job)
+ assert self.job == job
class TestDataProcPigOperator(unittest.TestCase):
@@ -721,9 +721,9 @@ class TestDataProcPigOperator(unittest.TestCase):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_deprecation_warning(self, mock_hook):
- with self.assertWarns(DeprecationWarning) as warning:
+ with pytest.warns(DeprecationWarning) as warnings:
DataprocSubmitPigJobOperator(task_id=TASK_ID, region=GCP_LOCATION, query="query")
- assert_warning("DataprocSubmitJobOperator", warning)
+ assert_warning("DataprocSubmitJobOperator", warnings)
@mock.patch(DATAPROC_PATH.format("uuid.uuid4"))
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -764,7 +764,7 @@ def test_builder(self, mock_hook, mock_uuid):
variables=self.variables,
)
job = op.generate_job()
- self.assertDictEqual(self.job, job)
+ assert self.job == job
class TestDataProcSparkSqlOperator(unittest.TestCase):
@@ -780,9 +780,9 @@ class TestDataProcSparkSqlOperator(unittest.TestCase):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_deprecation_warning(self, mock_hook):
- with self.assertWarns(DeprecationWarning) as warning:
+ with pytest.warns(DeprecationWarning) as warnings:
DataprocSubmitSparkSqlJobOperator(task_id=TASK_ID, region=GCP_LOCATION, query="query")
- assert_warning("DataprocSubmitJobOperator", warning)
+ assert_warning("DataprocSubmitJobOperator", warnings)
@mock.patch(DATAPROC_PATH.format("uuid.uuid4"))
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -823,7 +823,7 @@ def test_builder(self, mock_hook, mock_uuid):
variables=self.variables,
)
job = op.generate_job()
- self.assertDictEqual(self.job, job)
+ assert self.job == job
class TestDataProcSparkOperator(unittest.TestCase):
@@ -839,11 +839,11 @@ class TestDataProcSparkOperator(unittest.TestCase):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_deprecation_warning(self, mock_hook):
- with self.assertWarns(DeprecationWarning) as warning:
+ with pytest.warns(DeprecationWarning) as warnings:
DataprocSubmitSparkJobOperator(
task_id=TASK_ID, region=GCP_LOCATION, main_class=self.main_class, dataproc_jars=self.jars
)
- assert_warning("DataprocSubmitJobOperator", warning)
+ assert_warning("DataprocSubmitJobOperator", warnings)
@mock.patch(DATAPROC_PATH.format("uuid.uuid4"))
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -860,7 +860,7 @@ def test_execute(self, mock_hook, mock_uuid):
dataproc_jars=self.jars,
)
job = op.generate_job()
- self.assertDictEqual(self.job, job)
+ assert self.job == job
class TestDataProcHadoopOperator(unittest.TestCase):
@@ -876,11 +876,11 @@ class TestDataProcHadoopOperator(unittest.TestCase):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_deprecation_warning(self, mock_hook):
- with self.assertWarns(DeprecationWarning) as warning:
+ with pytest.warns(DeprecationWarning) as warnings:
DataprocSubmitHadoopJobOperator(
task_id=TASK_ID, region=GCP_LOCATION, main_jar=self.jar, arguments=self.args
)
- assert_warning("DataprocSubmitJobOperator", warning)
+ assert_warning("DataprocSubmitJobOperator", warnings)
@mock.patch(DATAPROC_PATH.format("uuid.uuid4"))
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -897,7 +897,7 @@ def test_execute(self, mock_hook, mock_uuid):
arguments=self.args,
)
job = op.generate_job()
- self.assertDictEqual(self.job, job)
+ assert self.job == job
class TestDataProcPySparkOperator(unittest.TestCase):
@@ -912,9 +912,9 @@ class TestDataProcPySparkOperator(unittest.TestCase):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_deprecation_warning(self, mock_hook):
- with self.assertWarns(DeprecationWarning) as warning:
+ with pytest.warns(DeprecationWarning) as warnings:
DataprocSubmitPySparkJobOperator(task_id=TASK_ID, region=GCP_LOCATION, main=self.uri)
- assert_warning("DataprocSubmitJobOperator", warning)
+ assert_warning("DataprocSubmitJobOperator", warnings)
@mock.patch(DATAPROC_PATH.format("uuid.uuid4"))
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -926,7 +926,7 @@ def test_execute(self, mock_hook, mock_uuid):
task_id=TASK_ID, region=GCP_LOCATION, gcp_conn_id=GCP_CONN_ID, main=self.uri
)
job = op.generate_job()
- self.assertDictEqual(self.job, job)
+ assert self.job == job
class TestDataprocCreateWorkflowTemplateOperator:
diff --git a/tests/providers/google/cloud/operators/test_functions.py b/tests/providers/google/cloud/operators/test_functions.py
index c56b261965fb8..5fc68de9e730f 100644
--- a/tests/providers/google/cloud/operators/test_functions.py
+++ b/tests/providers/google/cloud/operators/test_functions.py
@@ -20,6 +20,7 @@
from copy import deepcopy
from unittest import mock
+import pytest
from googleapiclient.errors import HttpError
from parameterized import parameterized
@@ -72,13 +73,13 @@ class TestGcfFunctionDeploy(unittest.TestCase):
@mock.patch('airflow.providers.google.cloud.operators.functions.CloudFunctionsHook')
def test_body_empty_or_missing_fields(self, body, message, mock_hook):
mock_hook.return_value.upload_function_zip.return_value = 'https://uploadUrl'
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = CloudFunctionDeployFunctionOperator(
project_id="test_project_id", location="test_region", body=body, task_id="id"
)
op.execute(None)
- err = cm.exception
- self.assertIn(message, str(err))
+ err = ctx.value
+ assert message in str(err)
@mock.patch('airflow.providers.google.cloud.operators.functions.CloudFunctionsHook')
def test_deploy_execute(self, mock_hook):
@@ -149,21 +150,21 @@ def test_empty_project_id_is_ok(self, mock_hook):
@mock.patch('airflow.providers.google.cloud.operators.functions.CloudFunctionsHook')
def test_empty_location(self, mock_hook):
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
CloudFunctionDeployFunctionOperator(
project_id="test_project_id", location="", body=None, task_id="id"
)
- err = cm.exception
- self.assertIn("The required parameter 'location' is missing", str(err))
+ err = ctx.value
+ assert "The required parameter 'location' is missing" in str(err)
@mock.patch('airflow.providers.google.cloud.operators.functions.CloudFunctionsHook')
def test_empty_body(self, mock_hook):
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
CloudFunctionDeployFunctionOperator(
project_id="test_project_id", location="test_region", body=None, task_id="id"
)
- err = cm.exception
- self.assertIn("The required parameter 'body' is missing", str(err))
+ err = ctx.value
+ assert "The required parameter 'body' is missing" in str(err)
@parameterized.expand([(runtime,) for runtime in VALID_RUNTIMES])
@mock.patch('airflow.providers.google.cloud.operators.functions.CloudFunctionsHook')
@@ -255,13 +256,13 @@ def test_body_validation_simple(self, mock_hook):
mock_hook.return_value.create_new_function.return_value = True
body = deepcopy(VALID_BODY)
body['name'] = ''
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = CloudFunctionDeployFunctionOperator(
project_id="test_project_id", location="test_region", body=body, task_id="id"
)
op.execute(None)
- err = cm.exception
- self.assertIn("The body field 'name' of value '' does not match", str(err))
+ err = ctx.value
+ assert "The body field 'name' of value '' does not match" in str(err)
mock_hook.assert_called_once_with(
api_version='v1',
gcp_conn_id='google_cloud_default',
@@ -288,13 +289,13 @@ def test_invalid_field_values(self, key, value, message, mock_hook):
mock_hook.return_value.create_new_function.return_value = True
body = deepcopy(VALID_BODY)
body[key] = value
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = CloudFunctionDeployFunctionOperator(
project_id="test_project_id", location="test_region", body=body, task_id="id"
)
op.execute(None)
- err = cm.exception
- self.assertIn(message, str(err))
+ err = ctx.value
+ assert message in str(err)
mock_hook.assert_called_once_with(
api_version='v1',
gcp_conn_id='google_cloud_default',
@@ -369,7 +370,7 @@ def test_invalid_source_code_union_field(self, source_code, message):
body.pop('sourceArchiveUrl', None)
zip_path = source_code.pop('zip_path', None)
body.update(source_code)
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = CloudFunctionDeployFunctionOperator(
project_id="test_project_id",
location="test_region",
@@ -378,8 +379,8 @@ def test_invalid_source_code_union_field(self, source_code, message):
zip_path=zip_path,
)
op.execute(None)
- err = cm.exception
- self.assertIn(message, str(err))
+ err = ctx.value
+ assert message in str(err)
# fmt: off
@parameterized.expand([
@@ -478,7 +479,7 @@ def test_invalid_trigger_union_field(self, trigger, message, mock_hook):
body.pop('httpsTrigger', None)
body.pop('eventTrigger', None)
body.update(trigger)
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = CloudFunctionDeployFunctionOperator(
project_id="test_project_id",
location="test_region",
@@ -486,8 +487,8 @@ def test_invalid_trigger_union_field(self, trigger, message, mock_hook):
task_id="id",
)
op.execute(None)
- err = cm.exception
- self.assertIn(message, str(err))
+ err = ctx.value
+ assert message in str(err)
mock_hook.assert_called_once_with(
api_version='v1',
gcp_conn_id='google_cloud_default',
@@ -607,7 +608,7 @@ def test_delete_execute(self, mock_hook):
mock_hook.return_value.delete_function.assert_called_once_with(
'projects/project_name/locations/project_location/functions/function_name'
)
- self.assertEqual(result['name'], self._FUNCTION_NAME)
+ assert result['name'] == self._FUNCTION_NAME
@mock.patch('airflow.providers.google.cloud.operators.functions.CloudFunctionsHook')
def test_correct_name(self, mock_hook):
@@ -623,20 +624,20 @@ def test_correct_name(self, mock_hook):
@mock.patch('airflow.providers.google.cloud.operators.functions.CloudFunctionsHook')
def test_invalid_name(self, mock_hook):
- with self.assertRaises(AttributeError) as cm:
+ with pytest.raises(AttributeError) as ctx:
op = CloudFunctionDeleteFunctionOperator(name="invalid_name", task_id="id")
op.execute(None)
- err = cm.exception
- self.assertEqual(str(err), f'Parameter name must match pattern: {FUNCTION_NAME_PATTERN}')
+ err = ctx.value
+ assert str(err) == f'Parameter name must match pattern: {FUNCTION_NAME_PATTERN}'
mock_hook.assert_not_called()
@mock.patch('airflow.providers.google.cloud.operators.functions.CloudFunctionsHook')
def test_empty_name(self, mock_hook):
mock_hook.return_value.delete_function.return_value = self._DELETE_FUNCTION_EXPECTED
- with self.assertRaises(AttributeError) as cm:
+ with pytest.raises(AttributeError) as ctx:
CloudFunctionDeleteFunctionOperator(name="", task_id="id")
- err = cm.exception
- self.assertEqual(str(err), 'Empty parameter: name')
+ err = ctx.value
+ assert str(err) == 'Empty parameter: name'
mock_hook.assert_not_called()
@mock.patch('airflow.providers.google.cloud.operators.functions.CloudFunctionsHook')
@@ -663,7 +664,7 @@ def test_non_404_gcf_error_bubbled_up(self, mock_hook):
side_effect=HttpError(resp=resp, content=b'error')
)
- with self.assertRaises(HttpError):
+ with pytest.raises(HttpError):
op.execute(None)
mock_hook.assert_called_once_with(
diff --git a/tests/providers/google/cloud/operators/test_gcs.py b/tests/providers/google/cloud/operators/test_gcs.py
index b41cf97710a6f..1c9954d297920 100644
--- a/tests/providers/google/cloud/operators/test_gcs.py
+++ b/tests/providers/google/cloud/operators/test_gcs.py
@@ -149,7 +149,7 @@ def test_execute(self, mock_hook):
mock_hook.return_value.list.assert_called_once_with(
bucket_name=TEST_BUCKET, prefix=PREFIX, delimiter=DELIMITER
)
- self.assertEqual(sorted(files), sorted(MOCK_FILES))
+ assert sorted(files) == sorted(MOCK_FILES)
class TestGCSFileTransformOperator(unittest.TestCase):
diff --git a/tests/providers/google/cloud/operators/test_kubernetes_engine.py b/tests/providers/google/cloud/operators/test_kubernetes_engine.py
index 2fc4dd8e237a4..22920d1493c56 100644
--- a/tests/providers/google/cloud/operators/test_kubernetes_engine.py
+++ b/tests/providers/google/cloud/operators/test_kubernetes_engine.py
@@ -21,6 +21,7 @@
from unittest import mock
from unittest.mock import PropertyMock
+import pytest
from parameterized import parameterized
from airflow.exceptions import AirflowException
@@ -75,7 +76,7 @@ def test_create_execute(self, body, mock_hook):
)
@mock.patch('airflow.providers.google.cloud.operators.kubernetes_engine.GKEHook')
def test_create_execute_error_body(self, body, mock_hook):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
GKECreateClusterOperator(
project_id=TEST_GCP_PROJECT_ID, location=PROJECT_LOCATION, body=body, task_id=PROJECT_TASK_ID
)
@@ -83,13 +84,13 @@ def test_create_execute_error_body(self, body, mock_hook):
# pylint: disable=missing-kwoa
@mock.patch('airflow.providers.google.cloud.operators.kubernetes_engine.GKEHook')
def test_create_execute_error_project_id(self, mock_hook):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
GKECreateClusterOperator(location=PROJECT_LOCATION, body=PROJECT_BODY, task_id=PROJECT_TASK_ID)
# pylint: disable=no-value-for-parameter
@mock.patch('airflow.providers.google.cloud.operators.kubernetes_engine.GKEHook')
def test_create_execute_error_location(self, mock_hook):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
GKECreateClusterOperator(
project_id=TEST_GCP_PROJECT_ID, body=PROJECT_BODY, task_id=PROJECT_TASK_ID
)
@@ -111,13 +112,13 @@ def test_delete_execute(self, mock_hook):
# pylint: disable=no-value-for-parameter
@mock.patch('airflow.providers.google.cloud.operators.kubernetes_engine.GKEHook')
def test_delete_execute_error_project_id(self, mock_hook):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
GKEDeleteClusterOperator(location=PROJECT_LOCATION, name=CLUSTER_NAME, task_id=PROJECT_TASK_ID)
# pylint: disable=missing-kwoa
@mock.patch('airflow.providers.google.cloud.operators.kubernetes_engine.GKEHook')
def test_delete_execute_error_cluster_name(self, mock_hook):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
GKEDeleteClusterOperator(
project_id=TEST_GCP_PROJECT_ID, location=PROJECT_LOCATION, task_id=PROJECT_TASK_ID
)
@@ -125,7 +126,7 @@ def test_delete_execute_error_cluster_name(self, mock_hook):
# pylint: disable=missing-kwoa
@mock.patch('airflow.providers.google.cloud.operators.kubernetes_engine.GKEHook')
def test_delete_execute_error_location(self, mock_hook):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
GKEDeleteClusterOperator(
project_id=TEST_GCP_PROJECT_ID, name=CLUSTER_NAME, task_id=PROJECT_TASK_ID
)
@@ -144,9 +145,7 @@ def setUp(self):
)
def test_template_fields(self):
- self.assertTrue(
- set(KubernetesPodOperator.template_fields).issubset(GKEStartPodOperator.template_fields)
- )
+ assert set(KubernetesPodOperator.template_fields).issubset(GKEStartPodOperator.template_fields)
# pylint: disable=unused-argument
@mock.patch.dict(os.environ, {})
@@ -187,7 +186,7 @@ def test_execute(self, file_mock, mock_execute_in_subprocess, mock_gcp_hook, exe
]
)
- self.assertEqual(self.gke_op.config_file, FILE_NAME)
+ assert self.gke_op.config_file == FILE_NAME
# pylint: disable=unused-argument
@mock.patch.dict(os.environ, {})
@@ -232,4 +231,4 @@ def test_execute_with_internal_ip(
]
)
- self.assertEqual(self.gke_op.config_file, FILE_NAME)
+ assert self.gke_op.config_file == FILE_NAME
diff --git a/tests/providers/google/cloud/operators/test_life_sciences.py b/tests/providers/google/cloud/operators/test_life_sciences.py
index cc08e13ff8ff8..1e86bd6c44c66 100644
--- a/tests/providers/google/cloud/operators/test_life_sciences.py
+++ b/tests/providers/google/cloud/operators/test_life_sciences.py
@@ -43,7 +43,7 @@ def test_executes(self, mock_hook):
task_id='task-id', body=TEST_BODY, location=TEST_LOCATION, project_id=TEST_PROJECT_ID
)
result = operator.execute(None)
- self.assertEqual(result, TEST_OPERATION)
+ assert result == TEST_OPERATION
@mock.patch("airflow.providers.google.cloud.operators.life_sciences.LifeSciencesHook")
def test_executes_without_project_id(self, mock_hook):
@@ -55,4 +55,4 @@ def test_executes_without_project_id(self, mock_hook):
location=TEST_LOCATION,
)
result = operator.execute(None)
- self.assertEqual(result, TEST_OPERATION)
+ assert result == TEST_OPERATION
diff --git a/tests/providers/google/cloud/operators/test_mlengine.py b/tests/providers/google/cloud/operators/test_mlengine.py
index 719dae9314704..d67c8a2b9986e 100644
--- a/tests/providers/google/cloud/operators/test_mlengine.py
+++ b/tests/providers/google/cloud/operators/test_mlengine.py
@@ -21,6 +21,7 @@
from unittest.mock import ANY, MagicMock, patch
import httplib2
+import pytest
from googleapiclient.errors import HttpError
from airflow.exceptions import AirflowException
@@ -145,7 +146,7 @@ def test_success_with_model(self, mock_hook):
},
use_existing_job_fn=ANY,
)
- self.assertEqual(success_message['predictionOutput'], prediction_output)
+ assert success_message['predictionOutput'] == prediction_output
@patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook')
def test_success_with_version(self, mock_hook):
@@ -184,7 +185,7 @@ def test_success_with_version(self, mock_hook):
job={'jobId': 'test_prediction', 'predictionInput': input_with_version},
use_existing_job_fn=ANY,
)
- self.assertEqual(success_message['predictionOutput'], prediction_output)
+ assert success_message['predictionOutput'] == prediction_output
@patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook')
def test_success_with_uri(self, mock_hook):
@@ -222,48 +223,43 @@ def test_success_with_uri(self, mock_hook):
job={'jobId': 'test_prediction', 'predictionInput': input_with_uri},
use_existing_job_fn=ANY,
)
- self.assertEqual(success_message['predictionOutput'], prediction_output)
+ assert success_message['predictionOutput'] == prediction_output
def test_invalid_model_origin(self):
# Test that both uri and model is given
task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
task_args['uri'] = 'gs://fake-uri/saved_model'
task_args['model_name'] = 'fake_model'
- with self.assertRaises(AirflowException) as context:
+ with pytest.raises(AirflowException) as ctx:
MLEngineStartBatchPredictionJobOperator(**task_args).execute(None)
- self.assertEqual(
- 'Ambiguous model origin: Both uri and ' 'model/version name are provided.', str(context.exception)
- )
+ assert 'Ambiguous model origin: Both uri and ' 'model/version name are provided.' == str(ctx.value)
# Test that both uri and model/version is given
task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
task_args['uri'] = 'gs://fake-uri/saved_model'
task_args['model_name'] = 'fake_model'
task_args['version_name'] = 'fake_version'
- with self.assertRaises(AirflowException) as context:
+ with pytest.raises(AirflowException) as ctx:
MLEngineStartBatchPredictionJobOperator(**task_args).execute(None)
- self.assertEqual(
- 'Ambiguous model origin: Both uri and ' 'model/version name are provided.', str(context.exception)
- )
+ assert 'Ambiguous model origin: Both uri and ' 'model/version name are provided.' == str(ctx.value)
# Test that a version is given without a model
task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
task_args['version_name'] = 'bare_version'
- with self.assertRaises(AirflowException) as context:
+ with pytest.raises(AirflowException) as ctx:
MLEngineStartBatchPredictionJobOperator(**task_args).execute(None)
- self.assertEqual(
- 'Missing model: Batch prediction expects a model ' 'name when a version name is provided.',
- str(context.exception),
+ assert (
+ 'Missing model: Batch prediction expects a model '
+ 'name when a version name is provided.' == str(ctx.value)
)
# Test that none of uri, model, model/version is given
task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
- with self.assertRaises(AirflowException) as context:
+ with pytest.raises(AirflowException) as ctx:
MLEngineStartBatchPredictionJobOperator(**task_args).execute(None)
- self.assertEqual(
+ assert (
'Missing model origin: Batch prediction expects a '
- 'model, a model & version combination, or a URI to a savedModel.',
- str(context.exception),
+ 'model, a model & version combination, or a URI to a savedModel.' == str(ctx.value)
)
@patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook')
@@ -277,7 +273,7 @@ def test_http_error(self, mock_hook):
resp=httplib2.Response({'status': http_error_code}), content=b'Forbidden'
)
- with self.assertRaises(HttpError) as context:
+ with pytest.raises(HttpError) as ctx:
prediction_task = MLEngineStartBatchPredictionJobOperator(
job_id='test_prediction',
project_id='test-project',
@@ -300,7 +296,7 @@ def test_http_error(self, mock_hook):
'test-project', {'jobId': 'test_prediction', 'predictionInput': input_with_model}, ANY
)
- self.assertEqual(http_error_code, context.exception.resp.status)
+ assert http_error_code == ctx.value.resp.status
@patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook')
def test_failed_job_error(self, mock_hook):
@@ -309,10 +305,10 @@ def test_failed_job_error(self, mock_hook):
task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
task_args['uri'] = 'a uri'
- with self.assertRaises(RuntimeError) as context:
+ with pytest.raises(RuntimeError) as ctx:
MLEngineStartBatchPredictionJobOperator(**task_args).execute(None)
- self.assertEqual('A failure message', str(context.exception))
+ assert 'A failure message' == str(ctx.value)
class TestMLEngineStartTrainingJobOperator(unittest.TestCase):
@@ -359,7 +355,7 @@ def test_success_create_training_job(self, mock_hook):
impersonation_chain=None,
)
# Make sure only 'create_job' is invoked on hook instance
- self.assertEqual(len(hook_instance.mock_calls), 1)
+ assert len(hook_instance.mock_calls) == 1
hook_instance.create_job.assert_called_once_with(
project_id='test-project', job=self.TRAINING_INPUT, use_existing_job_fn=ANY
)
@@ -402,7 +398,7 @@ def test_success_create_training_job_with_master_config(self, mock_hook):
impersonation_chain=None,
)
# Make sure only 'create_job' is invoked on hook instance
- self.assertEqual(len(hook_instance.mock_calls), 1)
+ assert len(hook_instance.mock_calls) == 1
hook_instance.create_job.assert_called_once_with(
project_id='test-project', job=training_input, use_existing_job_fn=ANY
)
@@ -446,7 +442,7 @@ def test_success_create_training_job_with_master_image(self, hook):
delegate_to=None,
impersonation_chain=None,
)
- self.assertEqual(len(hook_instance.mock_calls), 1)
+ assert len(hook_instance.mock_calls) == 1
hook_instance.create_job.assert_called_once_with(
project_id='test-project',
job=request,
@@ -481,7 +477,7 @@ def test_success_create_training_job_with_optional_args(self, mock_hook):
impersonation_chain=None,
)
# Make sure only 'create_job' is invoked on hook instance
- self.assertEqual(len(hook_instance.mock_calls), 1)
+ assert len(hook_instance.mock_calls) == 1
hook_instance.create_job.assert_called_once_with(
project_id='test-project', job=training_input, use_existing_job_fn=ANY
)
@@ -494,7 +490,7 @@ def test_http_error(self, mock_hook):
resp=httplib2.Response({'status': http_error_code}), content=b'Forbidden'
)
- with self.assertRaises(HttpError) as context:
+ with pytest.raises(HttpError) as ctx:
training_op = MLEngineStartTrainingJobOperator(**self.TRAINING_DEFAULT_ARGS)
training_op.execute(None)
@@ -504,11 +500,11 @@ def test_http_error(self, mock_hook):
impersonation_chain=None,
)
# Make sure only 'create_job' is invoked on hook instance
- self.assertEqual(len(hook_instance.mock_calls), 1)
+ assert len(hook_instance.mock_calls) == 1
hook_instance.create_job.assert_called_once_with(
project_id='test-project', job=self.TRAINING_INPUT, use_existing_job_fn=ANY
)
- self.assertEqual(http_error_code, context.exception.resp.status)
+ assert http_error_code == ctx.value.resp.status
@patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook')
def test_failed_job_error(self, mock_hook):
@@ -518,7 +514,7 @@ def test_failed_job_error(self, mock_hook):
hook_instance = mock_hook.return_value
hook_instance.create_job.return_value = failure_response
- with self.assertRaises(RuntimeError) as context:
+ with pytest.raises(RuntimeError) as ctx:
training_op = MLEngineStartTrainingJobOperator(**self.TRAINING_DEFAULT_ARGS)
training_op.execute(None)
@@ -528,11 +524,11 @@ def test_failed_job_error(self, mock_hook):
impersonation_chain=None,
)
# Make sure only 'create_job' is invoked on hook instance
- self.assertEqual(len(hook_instance.mock_calls), 1)
+ assert len(hook_instance.mock_calls) == 1
hook_instance.create_job.assert_called_once_with(
project_id='test-project', job=self.TRAINING_INPUT, use_existing_job_fn=ANY
)
- self.assertEqual('A failure message', str(context.exception))
+ assert 'A failure message' == str(ctx.value)
@patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook')
def test_console_extra_link(self, mock_hook):
@@ -551,15 +547,12 @@ def test_console_extra_link(self, mock_hook):
}
ti.xcom_push(key='gcp_metadata', value=gcp_metadata)
- self.assertEqual(
- f"https://console.cloud.google.com/ai-platform/jobs/{job_id}?project={project_id}",
- training_op.get_extra_links(DEFAULT_DATE, AIPlatformConsoleLink.name),
+ assert (
+ f"https://console.cloud.google.com/ai-platform/jobs/{job_id}?project={project_id}"
+ == training_op.get_extra_links(DEFAULT_DATE, AIPlatformConsoleLink.name)
)
- self.assertEqual(
- '',
- training_op.get_extra_links(datetime.datetime(2019, 1, 1), AIPlatformConsoleLink.name),
- )
+ assert '' == training_op.get_extra_links(datetime.datetime(2019, 1, 1), AIPlatformConsoleLink.name)
def test_console_extra_link_serialized_field(self):
with self.dag:
@@ -569,13 +562,12 @@ def test_console_extra_link_serialized_field(self):
simple_task = dag.task_dict[self.TRAINING_DEFAULT_ARGS['task_id']]
# Check Serialized version of operator link
- self.assertEqual(
- serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
- [{"airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink": {}}],
- )
+ assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
+ {"airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink": {}}
+ ]
# Check DeSerialized version of operator link
- self.assertIsInstance(list(simple_task.operator_extra_links)[0], AIPlatformConsoleLink)
+ assert isinstance(list(simple_task.operator_extra_links)[0], AIPlatformConsoleLink)
job_id = self.TRAINING_DEFAULT_ARGS['job_id']
project_id = self.TRAINING_DEFAULT_ARGS['project_id']
@@ -590,15 +582,12 @@ def test_console_extra_link_serialized_field(self):
)
ti.xcom_push(key='gcp_metadata', value=gcp_metadata)
- self.assertEqual(
- f"https://console.cloud.google.com/ai-platform/jobs/{job_id}?project={project_id}",
- simple_task.get_extra_links(DEFAULT_DATE, AIPlatformConsoleLink.name),
+ assert (
+ f"https://console.cloud.google.com/ai-platform/jobs/{job_id}?project={project_id}"
+ == simple_task.get_extra_links(DEFAULT_DATE, AIPlatformConsoleLink.name)
)
- self.assertEqual(
- '',
- simple_task.get_extra_links(datetime.datetime(2019, 1, 1), AIPlatformConsoleLink.name),
- )
+ assert '' == simple_task.get_extra_links(datetime.datetime(2019, 1, 1), AIPlatformConsoleLink.name)
class TestMLEngineTrainingCancelJobOperator(unittest.TestCase):
@@ -624,7 +613,7 @@ def test_success_cancel_training_job(self, mock_hook):
impersonation_chain=None,
)
# Make sure only 'cancel_job' is invoked on hook instance
- self.assertEqual(len(hook_instance.mock_calls), 1)
+ assert len(hook_instance.mock_calls) == 1
hook_instance.cancel_job.assert_called_once_with(
project_id=self.TRAINING_DEFAULT_ARGS['project_id'], job_id=self.TRAINING_DEFAULT_ARGS['job_id']
)
@@ -637,7 +626,7 @@ def test_http_error(self, mock_hook):
resp=httplib2.Response({'status': http_error_code}), content=b'Forbidden'
)
- with self.assertRaises(HttpError) as context:
+ with pytest.raises(HttpError) as ctx:
cancel_training_op = MLEngineTrainingCancelJobOperator(**self.TRAINING_DEFAULT_ARGS)
cancel_training_op.execute(None)
@@ -647,11 +636,11 @@ def test_http_error(self, mock_hook):
impersonation_chain=None,
)
# Make sure only 'create_job' is invoked on hook instance
- self.assertEqual(len(hook_instance.mock_calls), 1)
+ assert len(hook_instance.mock_calls) == 1
hook_instance.cancel_job.assert_called_once_with(
project_id=self.TRAINING_DEFAULT_ARGS['project_id'], job_id=self.TRAINING_DEFAULT_ARGS['job_id']
)
- self.assertEqual(http_error_code, context.exception.resp.status)
+ assert http_error_code == ctx.value.resp.status
class TestMLEngineModelOperator(unittest.TestCase):
@@ -700,7 +689,7 @@ def test_success_get_model(self, mock_hook):
mock_hook.return_value.get_model.assert_called_once_with(
project_id=TEST_PROJECT_ID, model_name=TEST_MODEL_NAME
)
- self.assertEqual(mock_hook.return_value.get_model.return_value, result)
+ assert mock_hook.return_value.get_model.return_value == result
@patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook')
def test_fail(self, mock_hook):
@@ -712,7 +701,7 @@ def test_fail(self, mock_hook):
gcp_conn_id=TEST_GCP_CONN_ID,
delegate_to=TEST_DELEGATE_TO,
)
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
task.execute(None)
@@ -762,7 +751,7 @@ def test_success_get_model(self, mock_hook):
mock_hook.return_value.get_model.assert_called_once_with(
project_id=TEST_PROJECT_ID, model_name=TEST_MODEL_NAME
)
- self.assertEqual(mock_hook.return_value.get_model.return_value, result)
+ assert mock_hook.return_value.get_model.return_value == result
class TestMLEngineDeleteModelOperator(unittest.TestCase):
@@ -812,7 +801,7 @@ def test_success_create_version(self, mock_hook):
impersonation_chain=None,
)
# Make sure only 'create_version' is invoked on hook instance
- self.assertEqual(len(hook_instance.mock_calls), 1)
+ assert len(hook_instance.mock_calls) == 1
hook_instance.create_version.assert_called_once_with(
project_id='test-project', model_name='test-model', version_spec=TEST_VERSION
)
@@ -843,7 +832,7 @@ def test_success(self, mock_hook):
)
def test_missing_model_name(self):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
MLEngineCreateVersionOperator(
task_id="task-id",
project_id=TEST_PROJECT_ID,
@@ -854,7 +843,7 @@ def test_missing_model_name(self):
)
def test_missing_version(self):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
MLEngineCreateVersionOperator(
task_id="task-id",
project_id=TEST_PROJECT_ID,
@@ -890,7 +879,7 @@ def test_success(self, mock_hook):
)
def test_missing_model_name(self):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
MLEngineSetDefaultVersionOperator(
task_id="task-id",
project_id=TEST_PROJECT_ID,
@@ -901,7 +890,7 @@ def test_missing_model_name(self):
)
def test_missing_version_name(self):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
MLEngineSetDefaultVersionOperator(
task_id="task-id",
project_id=TEST_PROJECT_ID,
@@ -937,7 +926,7 @@ def test_success(self, mock_hook):
)
def test_missing_model_name(self):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
MLEngineListVersionsOperator(
task_id="task-id",
project_id=TEST_PROJECT_ID,
@@ -972,7 +961,7 @@ def test_success(self, mock_hook):
)
def test_missing_version_name(self):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
MLEngineDeleteVersionOperator(
task_id="task-id",
project_id=TEST_PROJECT_ID,
@@ -983,7 +972,7 @@ def test_missing_version_name(self):
)
def test_missing_model_name(self):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
MLEngineDeleteVersionOperator(
task_id="task-id",
project_id=TEST_PROJECT_ID,
diff --git a/tests/providers/google/cloud/operators/test_mlengine_utils.py b/tests/providers/google/cloud/operators/test_mlengine_utils.py
index eafa77208d9cc..539ee608dec0f 100644
--- a/tests/providers/google/cloud/operators/test_mlengine_utils.py
+++ b/tests/providers/google/cloud/operators/test_mlengine_utils.py
@@ -20,6 +20,8 @@
from unittest import mock
from unittest.mock import ANY, patch
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.models.dag import DAG
from airflow.providers.google.cloud.utils import mlengine_operator_utils
@@ -102,7 +104,7 @@ def test_successful_run(self):
},
use_existing_job_fn=ANY,
)
- self.assertEqual(success_message['predictionOutput'], result)
+ assert success_message['predictionOutput'] == result
with patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook') as mock_dataflow_hook:
hook_instance = mock_dataflow_hook.return_value
@@ -141,7 +143,7 @@ def test_successful_run(self):
hook_instance.download.assert_called_once_with(
'legal-bucket', 'fake-output-path/prediction.summary.json'
)
- self.assertEqual('err=0.9', result)
+ assert 'err=0.9' == result
def test_failures(self):
def create_test_dag(dag_id):
@@ -169,12 +171,12 @@ def create_test_dag(dag_id):
'validate_fn': (lambda x: 'err=%.1f' % x['err']),
}
- with self.assertRaisesRegex(AirflowException, 'Missing model origin'):
+ with pytest.raises(AirflowException, match='Missing model origin'):
mlengine_operator_utils.create_evaluate_ops(
dag=create_test_dag('test_dag_1'), **other_params_but_models
)
- with self.assertRaisesRegex(AirflowException, 'Ambiguous model origin'):
+ with pytest.raises(AirflowException, match='Ambiguous model origin'):
mlengine_operator_utils.create_evaluate_ops(
dag=create_test_dag('test_dag_2'),
model_uri='abc',
@@ -182,7 +184,7 @@ def create_test_dag(dag_id):
**other_params_but_models,
)
- with self.assertRaisesRegex(AirflowException, 'Ambiguous model origin'):
+ with pytest.raises(AirflowException, match='Ambiguous model origin'):
mlengine_operator_utils.create_evaluate_ops(
dag=create_test_dag('test_dag_3'),
model_uri='abc',
@@ -190,14 +192,14 @@ def create_test_dag(dag_id):
**other_params_but_models,
)
- with self.assertRaisesRegex(AirflowException, '`metric_fn` param must be callable'):
+ with pytest.raises(AirflowException, match='`metric_fn` param must be callable'):
params = other_params_but_models.copy()
params['metric_fn_and_keys'] = (None, ['abc'])
mlengine_operator_utils.create_evaluate_ops(
dag=create_test_dag('test_dag_4'), model_uri='gs://blah', **params
)
- with self.assertRaisesRegex(AirflowException, '`validate_fn` param must be callable'):
+ with pytest.raises(AirflowException, match='`validate_fn` param must be callable'):
params = other_params_but_models.copy()
params['validate_fn'] = None
mlengine_operator_utils.create_evaluate_ops(
diff --git a/tests/providers/google/cloud/operators/test_natural_language.py b/tests/providers/google/cloud/operators/test_natural_language.py
index a6a0f3d7d9c5a..7cf0473ced314 100644
--- a/tests/providers/google/cloud/operators/test_natural_language.py
+++ b/tests/providers/google/cloud/operators/test_natural_language.py
@@ -52,7 +52,7 @@ def test_minimal_green_path(self, hook_mock):
hook_mock.return_value.analyze_entities.return_value = ANALYZE_ENTITIES_RESPONSE
op = CloudNaturalLanguageAnalyzeEntitiesOperator(task_id="task-id", document=DOCUMENT)
resp = op.execute({})
- self.assertEqual(resp, {})
+ assert resp == {}
class TestCloudLanguageAnalyzeEntitySentimentOperator(unittest.TestCase):
@@ -61,7 +61,7 @@ def test_minimal_green_path(self, hook_mock):
hook_mock.return_value.analyze_entity_sentiment.return_value = ANALYZE_ENTITY_SENTIMENT_RESPONSE
op = CloudNaturalLanguageAnalyzeEntitySentimentOperator(task_id="task-id", document=DOCUMENT)
resp = op.execute({})
- self.assertEqual(resp, {})
+ assert resp == {}
class TestCloudLanguageAnalyzeSentimentOperator(unittest.TestCase):
@@ -70,7 +70,7 @@ def test_minimal_green_path(self, hook_mock):
hook_mock.return_value.analyze_sentiment.return_value = ANALYZE_SENTIMENT_RESPONSE
op = CloudNaturalLanguageAnalyzeSentimentOperator(task_id="task-id", document=DOCUMENT)
resp = op.execute({})
- self.assertEqual(resp, {})
+ assert resp == {}
class TestCloudLanguageClassifyTextOperator(unittest.TestCase):
@@ -79,4 +79,4 @@ def test_minimal_green_path(self, hook_mock):
hook_mock.return_value.classify_text.return_value = CLASSIFY_TEXT_RRESPONSE
op = CloudNaturalLanguageClassifyTextOperator(task_id="task-id", document=DOCUMENT)
resp = op.execute({})
- self.assertEqual(resp, {})
+ assert resp == {}
diff --git a/tests/providers/google/cloud/operators/test_pubsub.py b/tests/providers/google/cloud/operators/test_pubsub.py
index 8e39343ea8201..6abfffa0e0c84 100644
--- a/tests/providers/google/cloud/operators/test_pubsub.py
+++ b/tests/providers/google/cloud/operators/test_pubsub.py
@@ -126,7 +126,7 @@ def test_execute(self, mock_hook):
timeout=None,
metadata=None,
)
- self.assertEqual(response, TEST_SUBSCRIPTION)
+ assert response == TEST_SUBSCRIPTION
@mock.patch('airflow.providers.google.cloud.operators.pubsub.PubSubHook')
def test_execute_different_project_ids(self, mock_hook):
@@ -160,7 +160,7 @@ def test_execute_different_project_ids(self, mock_hook):
timeout=None,
metadata=None,
)
- self.assertEqual(response, TEST_SUBSCRIPTION)
+ assert response == TEST_SUBSCRIPTION
@mock.patch('airflow.providers.google.cloud.operators.pubsub.PubSubHook')
def test_execute_no_subscription(self, mock_hook):
@@ -189,7 +189,7 @@ def test_execute_no_subscription(self, mock_hook):
timeout=None,
metadata=None,
)
- self.assertEqual(response, TEST_SUBSCRIPTION)
+ assert response == TEST_SUBSCRIPTION
class TestPubSubSubscriptionDeleteOperator(unittest.TestCase):
@@ -251,7 +251,7 @@ def test_execute_no_messages(self, mock_hook):
)
mock_hook.return_value.pull.return_value = []
- self.assertEqual([], operator.execute({}))
+ assert [] == operator.execute({})
@mock.patch('airflow.providers.google.cloud.operators.pubsub.PubSubHook')
def test_execute_with_ack_messages(self, mock_hook):
@@ -266,7 +266,7 @@ def test_execute_with_ack_messages(self, mock_hook):
generated_dicts = self._generate_dicts(5)
mock_hook.return_value.pull.return_value = generated_messages
- self.assertEqual(generated_dicts, operator.execute({}))
+ assert generated_dicts == operator.execute({})
mock_hook.return_value.acknowledge.assert_called_once_with(
project_id=TEST_PROJECT,
subscription=TEST_SUBSCRIPTION,
diff --git a/tests/providers/google/cloud/operators/test_spanner.py b/tests/providers/google/cloud/operators/test_spanner.py
index 4daccdb0e3936..6347fd24be4b9 100644
--- a/tests/providers/google/cloud/operators/test_spanner.py
+++ b/tests/providers/google/cloud/operators/test_spanner.py
@@ -18,6 +18,7 @@
import unittest
from unittest import mock
+import pytest
from parameterized import parameterized
from airflow.exceptions import AirflowException
@@ -68,7 +69,7 @@ def test_instance_create(self, mock_hook):
display_name=DISPLAY_NAME,
)
mock_hook.return_value.update_instance.assert_not_called()
- self.assertIsNone(result)
+ assert result is None
@mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook")
def test_instance_create_missing_project_id(self, mock_hook):
@@ -93,7 +94,7 @@ def test_instance_create_missing_project_id(self, mock_hook):
display_name=DISPLAY_NAME,
)
mock_hook.return_value.update_instance.assert_not_called()
- self.assertIsNone(result)
+ assert result is None
@mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook")
def test_instance_update(self, mock_hook):
@@ -119,7 +120,7 @@ def test_instance_update(self, mock_hook):
display_name=DISPLAY_NAME,
)
mock_hook.return_value.create_instance.assert_not_called()
- self.assertIsNone(result)
+ assert result is None
@mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook")
def test_instance_update_missing_project_id(self, mock_hook):
@@ -144,7 +145,7 @@ def test_instance_update_missing_project_id(self, mock_hook):
display_name=DISPLAY_NAME,
)
mock_hook.return_value.create_instance.assert_not_called()
- self.assertIsNone(result)
+ assert result is None
@mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook")
def test_instance_create_aborts_and_succeeds_if_instance_exists(self, mock_hook):
@@ -163,7 +164,7 @@ def test_instance_create_aborts_and_succeeds_if_instance_exists(self, mock_hook)
impersonation_chain=None,
)
mock_hook.return_value.create_instance.assert_not_called()
- self.assertIsNone(result)
+ assert result is None
@parameterized.expand(
[
@@ -173,7 +174,7 @@ def test_instance_create_aborts_and_succeeds_if_instance_exists(self, mock_hook)
)
@mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook")
def test_instance_create_ex_if_param_missing(self, project_id, instance_id, exp_msg, mock_hook):
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
SpannerDeployInstanceOperator(
project_id=project_id,
instance_id=instance_id,
@@ -182,8 +183,8 @@ def test_instance_create_ex_if_param_missing(self, project_id, instance_id, exp_
display_name=DISPLAY_NAME,
task_id="id",
)
- err = cm.exception
- self.assertIn(f"The required parameter '{exp_msg}' is empty", str(err))
+ err = ctx.value
+ assert f"The required parameter '{exp_msg}' is empty" in str(err)
mock_hook.assert_not_called()
@mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook")
@@ -198,7 +199,7 @@ def test_instance_delete(self, mock_hook):
mock_hook.return_value.delete_instance.assert_called_once_with(
project_id=PROJECT_ID, instance_id=INSTANCE_ID
)
- self.assertTrue(result)
+ assert result
@mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook")
def test_instance_delete_missing_project_id(self, mock_hook):
@@ -212,7 +213,7 @@ def test_instance_delete_missing_project_id(self, mock_hook):
mock_hook.return_value.delete_instance.assert_called_once_with(
project_id=None, instance_id=INSTANCE_ID
)
- self.assertTrue(result)
+ assert result
@mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook")
def test_instance_delete_aborts_and_succeeds_if_instance_does_not_exist(self, mock_hook):
@@ -224,7 +225,7 @@ def test_instance_delete_aborts_and_succeeds_if_instance_does_not_exist(self, mo
impersonation_chain=None,
)
mock_hook.return_value.delete_instance.assert_not_called()
- self.assertTrue(result)
+ assert result
@parameterized.expand(
[
@@ -234,10 +235,10 @@ def test_instance_delete_aborts_and_succeeds_if_instance_does_not_exist(self, mo
)
@mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook")
def test_instance_delete_ex_if_param_missing(self, project_id, instance_id, exp_msg, mock_hook):
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
SpannerDeleteInstanceOperator(project_id=project_id, instance_id=instance_id, task_id="id")
- err = cm.exception
- self.assertIn(f"The required parameter '{exp_msg}' is empty", str(err))
+ err = ctx.value
+ assert f"The required parameter '{exp_msg}' is empty" in str(err)
mock_hook.assert_not_called()
@mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook")
@@ -258,7 +259,7 @@ def test_instance_query(self, mock_hook):
mock_hook.return_value.execute_dml.assert_called_once_with(
project_id=PROJECT_ID, instance_id=INSTANCE_ID, database_id=DB_ID, queries=[INSERT_QUERY]
)
- self.assertIsNone(result)
+ assert result is None
@mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook")
def test_instance_query_missing_project_id(self, mock_hook):
@@ -274,7 +275,7 @@ def test_instance_query_missing_project_id(self, mock_hook):
mock_hook.return_value.execute_dml.assert_called_once_with(
project_id=None, instance_id=INSTANCE_ID, database_id=DB_ID, queries=[INSERT_QUERY]
)
- self.assertIsNone(result)
+ assert result is None
@parameterized.expand(
[
@@ -288,7 +289,7 @@ def test_instance_query_missing_project_id(self, mock_hook):
def test_instance_query_ex_if_param_missing(
self, project_id, instance_id, database_id, query, exp_msg, mock_hook
):
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
SpannerQueryDatabaseInstanceOperator(
project_id=project_id,
instance_id=instance_id,
@@ -296,8 +297,8 @@ def test_instance_query_ex_if_param_missing(
query=query,
task_id="id",
)
- err = cm.exception
- self.assertIn(f"The required parameter '{exp_msg}' is empty", str(err))
+ err = ctx.value
+ assert f"The required parameter '{exp_msg}' is empty" in str(err)
mock_hook.assert_not_called()
@mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook")
@@ -360,7 +361,7 @@ def test_database_create(self, mock_hook):
project_id=PROJECT_ID, instance_id=INSTANCE_ID, database_id=DB_ID, ddl_statements=DDL_STATEMENTS
)
mock_hook.return_value.update_database.assert_not_called()
- self.assertTrue(result)
+ assert result
@mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook")
def test_database_create_missing_project_id(self, mock_hook):
@@ -377,7 +378,7 @@ def test_database_create_missing_project_id(self, mock_hook):
project_id=None, instance_id=INSTANCE_ID, database_id=DB_ID, ddl_statements=DDL_STATEMENTS
)
mock_hook.return_value.update_database.assert_not_called()
- self.assertTrue(result)
+ assert result
@mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook")
def test_database_create_with_pre_existing_db(self, mock_hook):
@@ -396,7 +397,7 @@ def test_database_create_with_pre_existing_db(self, mock_hook):
)
mock_hook.return_value.create_database.assert_not_called()
mock_hook.return_value.update_database.assert_not_called()
- self.assertTrue(result)
+ assert result
@parameterized.expand(
[
@@ -409,7 +410,7 @@ def test_database_create_with_pre_existing_db(self, mock_hook):
def test_database_create_ex_if_param_missing(
self, project_id, instance_id, database_id, ddl_statements, exp_msg, mock_hook
):
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
SpannerDeployDatabaseInstanceOperator(
project_id=project_id,
instance_id=instance_id,
@@ -417,8 +418,8 @@ def test_database_create_ex_if_param_missing(
ddl_statements=ddl_statements,
task_id="id",
)
- err = cm.exception
- self.assertIn(f"The required parameter '{exp_msg}' is empty", str(err))
+ err = ctx.value
+ assert f"The required parameter '{exp_msg}' is empty" in str(err)
mock_hook.assert_not_called()
@mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook")
@@ -443,7 +444,7 @@ def test_database_update(self, mock_hook):
ddl_statements=DDL_STATEMENTS,
operation_id=None,
)
- self.assertTrue(result)
+ assert result
@mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook")
def test_database_update_missing_project_id(self, mock_hook):
@@ -463,7 +464,7 @@ def test_database_update_missing_project_id(self, mock_hook):
ddl_statements=DDL_STATEMENTS,
operation_id=None,
)
- self.assertTrue(result)
+ assert result
@parameterized.expand(
[
@@ -476,7 +477,7 @@ def test_database_update_missing_project_id(self, mock_hook):
def test_database_update_ex_if_param_missing(
self, project_id, instance_id, database_id, ddl_statements, exp_msg, mock_hook
):
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
SpannerUpdateDatabaseInstanceOperator(
project_id=project_id,
instance_id=instance_id,
@@ -484,14 +485,14 @@ def test_database_update_ex_if_param_missing(
ddl_statements=ddl_statements,
task_id="id",
)
- err = cm.exception
- self.assertIn(f"The required parameter '{exp_msg}' is empty", str(err))
+ err = ctx.value
+ assert f"The required parameter '{exp_msg}' is empty" in str(err)
mock_hook.assert_not_called()
@mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook")
def test_database_update_ex_if_database_not_exist(self, mock_hook):
mock_hook.return_value.get_database.return_value = None
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op = SpannerUpdateDatabaseInstanceOperator(
project_id=PROJECT_ID,
instance_id=INSTANCE_ID,
@@ -500,11 +501,10 @@ def test_database_update_ex_if_database_not_exist(self, mock_hook):
task_id="id",
)
op.execute(None)
- err = cm.exception
- self.assertIn(
+ err = ctx.value
+ assert (
"The Cloud Spanner database 'db1' in project 'project-id' and "
- "instance 'instance-id' is missing",
- str(err),
+ "instance 'instance-id' is missing" in str(err)
)
mock_hook.assert_called_once_with(
gcp_conn_id="google_cloud_default",
@@ -525,7 +525,7 @@ def test_database_delete(self, mock_hook):
mock_hook.return_value.delete_database.assert_called_once_with(
project_id=PROJECT_ID, instance_id=INSTANCE_ID, database_id=DB_ID
)
- self.assertTrue(result)
+ assert result
@mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook")
def test_database_delete_missing_project_id(self, mock_hook):
@@ -539,7 +539,7 @@ def test_database_delete_missing_project_id(self, mock_hook):
mock_hook.return_value.delete_database.assert_called_once_with(
project_id=None, instance_id=INSTANCE_ID, database_id=DB_ID
)
- self.assertTrue(result)
+ assert result
@mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook")
def test_database_delete_exits_and_succeeds_if_database_does_not_exist(self, mock_hook):
@@ -553,7 +553,7 @@ def test_database_delete_exits_and_succeeds_if_database_does_not_exist(self, moc
impersonation_chain=None,
)
mock_hook.return_value.delete_database.assert_not_called()
- self.assertTrue(result)
+ assert result
@parameterized.expand(
[
@@ -566,7 +566,7 @@ def test_database_delete_exits_and_succeeds_if_database_does_not_exist(self, moc
def test_database_delete_ex_if_param_missing(
self, project_id, instance_id, database_id, ddl_statements, exp_msg, mock_hook
):
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
SpannerDeleteDatabaseInstanceOperator(
project_id=project_id,
instance_id=instance_id,
@@ -574,6 +574,6 @@ def test_database_delete_ex_if_param_missing(
ddl_statements=ddl_statements,
task_id="id",
)
- err = cm.exception
- self.assertIn(f"The required parameter '{exp_msg}' is empty", str(err))
+ err = ctx.value
+ assert f"The required parameter '{exp_msg}' is empty" in str(err)
mock_hook.assert_not_called()
diff --git a/tests/providers/google/cloud/operators/test_speech_to_text.py b/tests/providers/google/cloud/operators/test_speech_to_text.py
index c9325ebe5924e..c13550e41b0f9 100644
--- a/tests/providers/google/cloud/operators/test_speech_to_text.py
+++ b/tests/providers/google/cloud/operators/test_speech_to_text.py
@@ -19,6 +19,8 @@
import unittest
from unittest.mock import MagicMock, Mock, patch
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.operators.speech_to_text import CloudSpeechToTextRecognizeSpeechOperator
@@ -55,24 +57,24 @@ def test_recognize_speech_green_path(self, mock_hook):
def test_missing_config(self, mock_hook):
mock_hook.return_value.recognize_speech.return_value = True
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
CloudSpeechToTextRecognizeSpeechOperator( # pylint: disable=missing-kwoa
project_id=PROJECT_ID, gcp_conn_id=GCP_CONN_ID, audio=AUDIO, task_id="id"
).execute(context={"task_instance": Mock()})
- err = e.exception
- self.assertIn("config", str(err))
+ err = ctx.value
+ assert "config" in str(err)
mock_hook.assert_not_called()
@patch("airflow.providers.google.cloud.operators.speech_to_text.CloudSpeechToTextHook")
def test_missing_audio(self, mock_hook):
mock_hook.return_value.recognize_speech.return_value = True
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
CloudSpeechToTextRecognizeSpeechOperator( # pylint: disable=missing-kwoa
project_id=PROJECT_ID, gcp_conn_id=GCP_CONN_ID, config=CONFIG, task_id="id"
).execute(context={"task_instance": Mock()})
- err = e.exception
- self.assertIn("audio", str(err))
+ err = ctx.value
+ assert "audio" in str(err)
mock_hook.assert_not_called()
diff --git a/tests/providers/google/cloud/operators/test_stackdriver.py b/tests/providers/google/cloud/operators/test_stackdriver.py
index fdf28dc2b2620..39899d0d1de06 100644
--- a/tests/providers/google/cloud/operators/test_stackdriver.py
+++ b/tests/providers/google/cloud/operators/test_stackdriver.py
@@ -108,7 +108,7 @@ def test_execute(self, mock_hook):
timeout=DEFAULT,
metadata=None,
)
- self.assertEqual([{'name': 'test-name'}], result)
+ assert [{'name': 'test-name'}] == result
class TestStackdriverEnableAlertPoliciesOperator(unittest.TestCase):
@@ -179,7 +179,7 @@ def test_execute(self, mock_hook):
timeout=DEFAULT,
metadata=None,
)
- self.assertEqual([{'name': 'test-123'}], result)
+ assert [{'name': 'test-123'}] == result
class TestStackdriverEnableNotificationChannelsOperator(unittest.TestCase):
diff --git a/tests/providers/google/cloud/operators/test_tasks.py b/tests/providers/google/cloud/operators/test_tasks.py
index ed7691169fe4c..58523bcd66700 100644
--- a/tests/providers/google/cloud/operators/test_tasks.py
+++ b/tests/providers/google/cloud/operators/test_tasks.py
@@ -57,7 +57,7 @@ def test_create_queue(self, mock_hook):
result = operator.execute(context=None)
- self.assertEqual({'name': FULL_QUEUE_PATH, 'state': 0}, result)
+ assert {'name': FULL_QUEUE_PATH, 'state': 0} == result
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=None,
@@ -81,7 +81,7 @@ def test_update_queue(self, mock_hook):
result = operator.execute(context=None)
- self.assertEqual({'name': FULL_QUEUE_PATH, 'state': 0}, result)
+ assert {'name': FULL_QUEUE_PATH, 'state': 0} == result
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=None,
@@ -106,7 +106,7 @@ def test_get_queue(self, mock_hook):
result = operator.execute(context=None)
- self.assertEqual({'name': FULL_QUEUE_PATH, 'state': 0}, result)
+ assert {'name': FULL_QUEUE_PATH, 'state': 0} == result
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=None,
@@ -129,7 +129,7 @@ def test_list_queues(self, mock_hook):
result = operator.execute(context=None)
- self.assertEqual([{'name': FULL_QUEUE_PATH, 'state': 0}], result)
+ assert [{'name': FULL_QUEUE_PATH, 'state': 0}] == result
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=None,
@@ -153,7 +153,7 @@ def test_delete_queue(self, mock_hook):
result = operator.execute(context=None)
- self.assertEqual(None, result)
+ assert result is None
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=None,
@@ -176,7 +176,7 @@ def test_delete_queue(self, mock_hook):
result = operator.execute(context=None)
- self.assertEqual({'name': FULL_QUEUE_PATH, 'state': 0}, result)
+ assert {'name': FULL_QUEUE_PATH, 'state': 0} == result
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=None,
@@ -199,7 +199,7 @@ def test_pause_queue(self, mock_hook):
result = operator.execute(context=None)
- self.assertEqual({'name': FULL_QUEUE_PATH, 'state': 0}, result)
+ assert {'name': FULL_QUEUE_PATH, 'state': 0} == result
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=None,
@@ -222,7 +222,7 @@ def test_resume_queue(self, mock_hook):
result = operator.execute(context=None)
- self.assertEqual({'name': FULL_QUEUE_PATH, 'state': 0}, result)
+ assert {'name': FULL_QUEUE_PATH, 'state': 0} == result
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=None,
@@ -247,16 +247,13 @@ def test_create_task(self, mock_hook):
result = operator.execute(context=None)
- self.assertEqual(
- {
- 'app_engine_http_request': {'body': '', 'headers': {}, 'http_method': 0, 'relative_uri': ''},
- 'dispatch_count': 0,
- 'name': '',
- 'response_count': 0,
- 'view': 0,
- },
- result,
- )
+ assert {
+ 'app_engine_http_request': {'body': '', 'headers': {}, 'http_method': 0, 'relative_uri': ''},
+ 'dispatch_count': 0,
+ 'name': '',
+ 'response_count': 0,
+ 'view': 0,
+ } == result
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=None,
@@ -284,16 +281,13 @@ def test_get_task(self, mock_hook):
result = operator.execute(context=None)
- self.assertEqual(
- {
- 'app_engine_http_request': {'body': '', 'headers': {}, 'http_method': 0, 'relative_uri': ''},
- 'dispatch_count': 0,
- 'name': '',
- 'response_count': 0,
- 'view': 0,
- },
- result,
- )
+ assert {
+ 'app_engine_http_request': {'body': '', 'headers': {}, 'http_method': 0, 'relative_uri': ''},
+ 'dispatch_count': 0,
+ 'name': '',
+ 'response_count': 0,
+ 'view': 0,
+ } == result
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=None,
@@ -318,23 +312,20 @@ def test_list_tasks(self, mock_hook):
result = operator.execute(context=None)
- self.assertEqual(
- [
- {
- 'app_engine_http_request': {
- 'body': '',
- 'headers': {},
- 'http_method': 0,
- 'relative_uri': '',
- },
- 'dispatch_count': 0,
- 'name': '',
- 'response_count': 0,
- 'view': 0,
- }
- ],
- result,
- )
+ assert [
+ {
+ 'app_engine_http_request': {
+ 'body': '',
+ 'headers': {},
+ 'http_method': 0,
+ 'relative_uri': '',
+ },
+ 'dispatch_count': 0,
+ 'name': '',
+ 'response_count': 0,
+ 'view': 0,
+ }
+ ] == result
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=None,
@@ -361,7 +352,7 @@ def test_delete_task(self, mock_hook):
result = operator.execute(context=None)
- self.assertEqual(None, result)
+ assert result is None
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=None,
@@ -387,16 +378,13 @@ def test_run_task(self, mock_hook):
result = operator.execute(context=None)
- self.assertEqual(
- {
- 'app_engine_http_request': {'body': '', 'headers': {}, 'http_method': 0, 'relative_uri': ''},
- 'dispatch_count': 0,
- 'name': '',
- 'response_count': 0,
- 'view': 0,
- },
- result,
- )
+ assert {
+ 'app_engine_http_request': {'body': '', 'headers': {}, 'http_method': 0, 'relative_uri': ''},
+ 'dispatch_count': 0,
+ 'name': '',
+ 'response_count': 0,
+ 'view': 0,
+ } == result
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=None,
diff --git a/tests/providers/google/cloud/operators/test_text_to_speech.py b/tests/providers/google/cloud/operators/test_text_to_speech.py
index 006c6b5df77a6..6b6ac3135ca95 100644
--- a/tests/providers/google/cloud/operators/test_text_to_speech.py
+++ b/tests/providers/google/cloud/operators/test_text_to_speech.py
@@ -19,6 +19,7 @@
import unittest
from unittest.mock import ANY, Mock, PropertyMock, patch
+import pytest
from parameterized import parameterized
from airflow.exceptions import AirflowException
@@ -93,7 +94,7 @@ def test_missing_arguments(
mock_text_to_speech_hook,
mock_gcp_hook,
):
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
CloudTextToSpeechSynthesizeOperator(
project_id="project-id",
input_data=input_data,
@@ -104,7 +105,7 @@ def test_missing_arguments(
task_id="id",
).execute(context={"task_instance": Mock()})
- err = e.exception
- self.assertIn(missing_arg, str(err))
+ err = ctx.value
+ assert missing_arg in str(err)
mock_text_to_speech_hook.assert_not_called()
mock_gcp_hook.assert_not_called()
diff --git a/tests/providers/google/cloud/operators/test_translate.py b/tests/providers/google/cloud/operators/test_translate.py
index c32bd2f709701..5df2abd41d2cd 100644
--- a/tests/providers/google/cloud/operators/test_translate.py
+++ b/tests/providers/google/cloud/operators/test_translate.py
@@ -58,14 +58,11 @@ def test_minimal_green_path(self, mock_hook):
source_language=None,
model='base',
)
- self.assertEqual(
- [
- {
- 'translatedText': 'Yellowing self Gęśle',
- 'detectedSourceLanguage': 'pl',
- 'model': 'base',
- 'input': 'zażółć gęślą jaźń',
- }
- ],
- return_value,
- )
+ assert [
+ {
+ 'translatedText': 'Yellowing self Gęśle',
+ 'detectedSourceLanguage': 'pl',
+ 'model': 'base',
+ 'input': 'zażółć gęślą jaźń',
+ }
+ ] == return_value
diff --git a/tests/providers/google/cloud/operators/test_translate_speech.py b/tests/providers/google/cloud/operators/test_translate_speech.py
index fc1c6376cadcc..ec170e2058860 100644
--- a/tests/providers/google/cloud/operators/test_translate_speech.py
+++ b/tests/providers/google/cloud/operators/test_translate_speech.py
@@ -19,6 +19,7 @@
import unittest
from unittest import mock
+import pytest
from google.cloud.speech_v1.proto.cloud_speech_pb2 import (
RecognizeResponse,
SpeechRecognitionAlternative,
@@ -86,17 +87,14 @@ def test_minimal_green_path(self, mock_translate_hook, mock_speech_hook):
source_language=None,
model='base',
)
- self.assertEqual(
- [
- {
- 'translatedText': 'sprawdzić wynik rozpoznawania mowy',
- 'detectedSourceLanguage': 'en',
- 'model': 'base',
- 'input': 'test speech recognition result',
- }
- ],
- return_value,
- )
+ assert [
+ {
+ 'translatedText': 'sprawdzić wynik rozpoznawania mowy',
+ 'detectedSourceLanguage': 'en',
+ 'model': 'base',
+ 'input': 'test speech recognition result',
+ }
+ ] == return_value
@mock.patch('airflow.providers.google.cloud.operators.translate_speech.CloudSpeechToTextHook')
@mock.patch('airflow.providers.google.cloud.operators.translate_speech.CloudTranslateHook')
@@ -114,10 +112,10 @@ def test_bad_recognition_response(self, mock_translate_hook, mock_speech_hook):
gcp_conn_id=GCP_CONN_ID,
task_id='id',
)
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op.execute(context=None)
- err = cm.exception
- self.assertIn("it should contain 'alternatives' field", str(err))
+ err = ctx.value
+ assert "it should contain 'alternatives' field" in str(err)
mock_speech_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
diff --git a/tests/providers/google/cloud/operators/test_vision.py b/tests/providers/google/cloud/operators/test_vision.py
index 2ca8d9ae70b89..aebd3d655435a 100644
--- a/tests/providers/google/cloud/operators/test_vision.py
+++ b/tests/providers/google/cloud/operators/test_vision.py
@@ -92,7 +92,7 @@ def test_already_exists(self, mock_hook):
task_id='id',
)
result = op.execute(None)
- self.assertEqual(PRODUCTSET_ID_TEST, result)
+ assert PRODUCTSET_ID_TEST == result
class TestCloudVisionProductSetUpdate(unittest.TestCase):
@@ -195,7 +195,7 @@ def test_already_exists(self, mock_hook):
task_id='id',
)
result = op.execute(None)
- self.assertEqual(PRODUCT_ID_TEST, result)
+ assert PRODUCT_ID_TEST == result
class TestCloudVisionProductGet(unittest.TestCase):
diff --git a/tests/providers/google/cloud/secrets/test_secret_manager.py b/tests/providers/google/cloud/secrets/test_secret_manager.py
index 2b2cbf0d2ed30..12ddf051b556c 100644
--- a/tests/providers/google/cloud/secrets/test_secret_manager.py
+++ b/tests/providers/google/cloud/secrets/test_secret_manager.py
@@ -15,8 +15,10 @@
# specific language governing permissions and limitations
# under the License.
+import re
from unittest import TestCase, mock
+import pytest
from google.api_core.exceptions import NotFound
from google.cloud.secretmanager_v1.types import AccessSecretVersionResponse
from parameterized import parameterized
@@ -52,7 +54,7 @@ def test_default_valid_and_sep(self, mock_client_callable, mock_get_creds):
mock_client_callable.return_value = mock_client
backend = CloudSecretManagerBackend()
- self.assertTrue(backend._is_valid_prefix_and_sep())
+ assert backend._is_valid_prefix_and_sep()
@parameterized.expand(
[
@@ -63,7 +65,7 @@ def test_default_valid_and_sep(self, mock_client_callable, mock_get_creds):
]
)
def test_raise_exception_with_invalid_prefix_sep(self, _, prefix, sep):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
CloudSecretManagerBackend(connections_prefix=prefix, sep=sep)
@parameterized.expand(
@@ -85,7 +87,7 @@ def test_is_valid_prefix_and_sep(self, _, prefix, sep, is_valid, mock_client_cal
backend = CloudSecretManagerBackend()
backend.connections_prefix = prefix
backend.sep = sep
- self.assertEqual(backend._is_valid_prefix_and_sep(), is_valid)
+ assert backend._is_valid_prefix_and_sep() == is_valid
@parameterized.expand(["airflow-connections", "connections", "airflow"])
@mock.patch(MODULE_NAME + ".get_credentials_and_project_id")
@@ -102,7 +104,7 @@ def test_get_conn_uri(self, connections_prefix, mock_client_callable, mock_get_c
secrets_manager_backend = CloudSecretManagerBackend(connections_prefix=connections_prefix)
secret_id = secrets_manager_backend.build_path(connections_prefix, CONN_ID, SEP)
returned_uri = secrets_manager_backend.get_conn_uri(conn_id=CONN_ID)
- self.assertEqual(CONN_URI, returned_uri)
+ assert CONN_URI == returned_uri
mock_client.secret_version_path.assert_called_once_with(PROJECT_ID, secret_id, "latest")
@mock.patch(MODULE_NAME + ".get_credentials_and_project_id")
@@ -111,8 +113,8 @@ def test_get_connections(self, mock_get_uri, mock_get_creds):
mock_get_creds.return_value = CREDENTIALS, PROJECT_ID
mock_get_uri.return_value = CONN_URI
conns = CloudSecretManagerBackend().get_connections(conn_id=CONN_ID)
- self.assertIsInstance(conns, list)
- self.assertIsInstance(conns[0], Connection)
+ assert isinstance(conns, list)
+ assert isinstance(conns[0], Connection)
@mock.patch(MODULE_NAME + ".get_credentials_and_project_id")
@mock.patch(CLIENT_MODULE_NAME + ".SecretManagerServiceClient")
@@ -126,11 +128,11 @@ def test_get_conn_uri_non_existent_key(self, mock_client_callable, mock_get_cred
secrets_manager_backend = CloudSecretManagerBackend(connections_prefix=CONNECTIONS_PREFIX)
secret_id = secrets_manager_backend.build_path(CONNECTIONS_PREFIX, CONN_ID, SEP)
with self.assertLogs(secrets_manager_backend.client.log, level="ERROR") as log_output:
- self.assertIsNone(secrets_manager_backend.get_conn_uri(conn_id=CONN_ID))
- self.assertEqual([], secrets_manager_backend.get_connections(conn_id=CONN_ID))
- self.assertRegex(
- log_output.output[0],
+ assert secrets_manager_backend.get_conn_uri(conn_id=CONN_ID) is None
+ assert [] == secrets_manager_backend.get_connections(conn_id=CONN_ID)
+ assert re.search(
f"Google Cloud API Call Error \\(NotFound\\): Secret ID {secret_id} not found",
+ log_output.output[0],
)
@parameterized.expand(["airflow-variables", "variables", "airflow"])
@@ -148,7 +150,7 @@ def test_get_variable(self, variables_prefix, mock_client_callable, mock_get_cre
secrets_manager_backend = CloudSecretManagerBackend(variables_prefix=variables_prefix)
secret_id = secrets_manager_backend.build_path(variables_prefix, VAR_KEY, SEP)
returned_uri = secrets_manager_backend.get_variable(VAR_KEY)
- self.assertEqual(VAR_VALUE, returned_uri)
+ assert VAR_VALUE == returned_uri
mock_client.secret_version_path.assert_called_once_with(PROJECT_ID, secret_id, "latest")
@parameterized.expand(["airflow-config", "config", "airflow"])
@@ -166,7 +168,7 @@ def test_get_config(self, config_prefix, mock_client_callable, mock_get_creds):
secrets_manager_backend = CloudSecretManagerBackend(config_prefix=config_prefix)
secret_id = secrets_manager_backend.build_path(config_prefix, CONFIG_KEY, SEP)
returned_val = secrets_manager_backend.get_config(CONFIG_KEY)
- self.assertEqual(CONFIG_VALUE, returned_val)
+ assert CONFIG_VALUE == returned_val
mock_client.secret_version_path.assert_called_once_with(PROJECT_ID, secret_id, "latest")
@parameterized.expand(["airflow-variables", "variables", "airflow"])
@@ -186,7 +188,7 @@ def test_get_variable_override_project_id(self, variables_prefix, mock_client_ca
)
secret_id = secrets_manager_backend.build_path(variables_prefix, VAR_KEY, SEP)
returned_uri = secrets_manager_backend.get_variable(VAR_KEY)
- self.assertEqual(VAR_VALUE, returned_uri)
+ assert VAR_VALUE == returned_uri
mock_client.secret_version_path.assert_called_once_with(OVERRIDDEN_PROJECT_ID, secret_id, "latest")
@mock.patch(MODULE_NAME + ".get_credentials_and_project_id")
@@ -201,10 +203,10 @@ def test_get_variable_non_existent_key(self, mock_client_callable, mock_get_cred
secrets_manager_backend = CloudSecretManagerBackend(variables_prefix=VARIABLES_PREFIX)
secret_id = secrets_manager_backend.build_path(VARIABLES_PREFIX, VAR_KEY, SEP)
with self.assertLogs(secrets_manager_backend.client.log, level="ERROR") as log_output:
- self.assertIsNone(secrets_manager_backend.get_variable(VAR_KEY))
- self.assertRegex(
- log_output.output[0],
+ assert secrets_manager_backend.get_variable(VAR_KEY) is None
+ assert re.search(
f"Google Cloud API Call Error \\(NotFound\\): Secret ID {secret_id} not found",
+ log_output.output[0],
)
@mock.patch(MODULE_NAME + ".get_credentials_and_project_id")
@@ -221,7 +223,7 @@ def test_connections_prefix_none_value(self, mock_client_callable, mock_get_cred
secrets_manager_backend = CloudSecretManagerBackend(connections_prefix=None)
mock_is_valid_prefix_sep.assert_not_called()
- self.assertIsNone(secrets_manager_backend.get_conn_uri(conn_id=CONN_ID))
+ assert secrets_manager_backend.get_conn_uri(conn_id=CONN_ID) is None
mock_get_secret.assert_not_called()
@mock.patch(MODULE_NAME + ".get_credentials_and_project_id")
@@ -234,7 +236,7 @@ def test_variables_prefix_none_value(self, mock_client_callable, mock_get_creds)
with mock.patch(MODULE_NAME + '.CloudSecretManagerBackend._get_secret') as mock_get_secret:
secrets_manager_backend = CloudSecretManagerBackend(variables_prefix=None)
- self.assertIsNone(secrets_manager_backend.get_variable(VAR_KEY))
+ assert secrets_manager_backend.get_variable(VAR_KEY) is None
mock_get_secret.assert_not_called()
@mock.patch(MODULE_NAME + ".get_credentials_and_project_id")
@@ -247,5 +249,5 @@ def test_config_prefix_none_value(self, mock_client_callable, mock_get_creds):
with mock.patch(MODULE_NAME + '.CloudSecretManagerBackend._get_secret') as mock_get_secret:
secrets_manager_backend = CloudSecretManagerBackend(config_prefix=None)
- self.assertIsNone(secrets_manager_backend.get_config(CONFIG_KEY))
+ assert secrets_manager_backend.get_config(CONFIG_KEY) is None
mock_get_secret.assert_not_called()
diff --git a/tests/providers/google/cloud/secrets/test_secret_manager_system.py b/tests/providers/google/cloud/secrets/test_secret_manager_system.py
index 34293132d2ec6..ba9aab9320a1a 100644
--- a/tests/providers/google/cloud/secrets/test_secret_manager_system.py
+++ b/tests/providers/google/cloud/secrets/test_secret_manager_system.py
@@ -42,7 +42,7 @@ def test_should_read_secret_from_variable(self):
{self.secret_name} --data-file=- --replication-policy=automatic'
subprocess.run(["bash", "-c", cmd], check=True)
result = subprocess.check_output(['airflow', 'variables', 'get', self.name])
- self.assertIn("TEST_CONTENT", result.decode())
+ assert "TEST_CONTENT" in result.decode()
@provide_gcp_context(GCP_SECRET_MANAGER_KEY, project_id=GoogleSystemTest._project_id())
def tearDown(self) -> None:
@@ -63,7 +63,7 @@ def test_should_read_secret_from_variable(self):
{self.secret_name} --data-file=- --replication-policy=automatic'
subprocess.run(["bash", "-c", cmd], check=True)
result = subprocess.check_output(['airflow', 'connections', 'get', self.name])
- self.assertIn("URI: mysql://user:pass@example.org", result.decode())
+ assert "URI: mysql://user:pass@example.org" in result.decode()
@provide_gcp_context(GCP_SECRET_MANAGER_KEY, project_id=GoogleSystemTest._project_id())
def tearDown(self) -> None:
diff --git a/tests/providers/google/cloud/sensors/test_bigquery.py b/tests/providers/google/cloud/sensors/test_bigquery.py
index 52af97295390a..488224f6249de 100644
--- a/tests/providers/google/cloud/sensors/test_bigquery.py
+++ b/tests/providers/google/cloud/sensors/test_bigquery.py
@@ -46,7 +46,7 @@ def test_passing_arguments_to_hook(self, mock_hook):
mock_hook.return_value.table_exists.return_value = True
results = task.poke(mock.MagicMock())
- self.assertEqual(True, results)
+ assert results is True
mock_hook.assert_called_once_with(
bigquery_conn_id=TEST_GCP_CONN_ID,
@@ -74,7 +74,7 @@ def test_passing_arguments_to_hook(self, mock_hook):
mock_hook.return_value.table_partition_exists.return_value = True
results = task.poke(mock.MagicMock())
- self.assertEqual(True, results)
+ assert results is True
mock_hook.assert_called_once_with(
bigquery_conn_id=TEST_GCP_CONN_ID,
diff --git a/tests/providers/google/cloud/sensors/test_bigquery_dts.py b/tests/providers/google/cloud/sensors/test_bigquery_dts.py
index c8a05483f1ec6..df51a17b65ab0 100644
--- a/tests/providers/google/cloud/sensors/test_bigquery_dts.py
+++ b/tests/providers/google/cloud/sensors/test_bigquery_dts.py
@@ -43,7 +43,7 @@ def test_poke_returns_false(self, mock_hook):
)
result = op.poke({})
- self.assertEqual(result, False)
+ assert result is False
mock_hook.return_value.get_transfer_run.assert_called_once_with(
transfer_config_id=TRANSFER_CONFIG_ID,
run_id=RUN_ID,
@@ -67,7 +67,7 @@ def test_poke_returns_true(self, mock_hook):
)
result = op.poke({})
- self.assertEqual(result, True)
+ assert result is True
mock_hook.return_value.get_transfer_run.assert_called_once_with(
transfer_config_id=TRANSFER_CONFIG_ID,
run_id=RUN_ID,
diff --git a/tests/providers/google/cloud/sensors/test_bigtable.py b/tests/providers/google/cloud/sensors/test_bigtable.py
index e6df23c613e1b..7fe35da5b70c5 100644
--- a/tests/providers/google/cloud/sensors/test_bigtable.py
+++ b/tests/providers/google/cloud/sensors/test_bigtable.py
@@ -20,6 +20,7 @@
from unittest import mock
import google.api_core.exceptions
+import pytest
from google.cloud.bigtable.instance import Instance
from google.cloud.bigtable.table import ClusterState
from parameterized import parameterized
@@ -44,7 +45,7 @@ class BigtableWaitForTableReplicationTest(unittest.TestCase):
)
@mock.patch('airflow.providers.google.cloud.sensors.bigtable.BigtableHook')
def test_empty_attribute(self, missing_attribute, project_id, instance_id, table_id, mock_hook):
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
BigtableTableReplicationCompletedSensor(
project_id=project_id,
instance_id=instance_id,
@@ -53,8 +54,8 @@ def test_empty_attribute(self, missing_attribute, project_id, instance_id, table
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
- err = e.exception
- self.assertEqual(str(err), f'Empty parameter: {missing_attribute}')
+ err = ctx.value
+ assert str(err) == f'Empty parameter: {missing_attribute}'
mock_hook.assert_not_called()
@mock.patch('airflow.providers.google.cloud.sensors.bigtable.BigtableHook')
@@ -69,7 +70,7 @@ def test_wait_no_instance(self, mock_hook):
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
- self.assertFalse(op.poke(None))
+ assert not op.poke(None)
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
@@ -90,7 +91,7 @@ def test_wait_no_table(self, mock_hook):
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
- self.assertFalse(op.poke(None))
+ assert not op.poke(None)
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
@@ -108,7 +109,7 @@ def test_wait_not_ready(self, mock_hook):
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
- self.assertFalse(op.poke(None))
+ assert not op.poke(None)
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
@@ -126,7 +127,7 @@ def test_wait_ready(self, mock_hook):
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
- self.assertTrue(op.poke(None))
+ assert op.poke(None)
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
diff --git a/tests/providers/google/cloud/sensors/test_cloud_storage_transfer_service.py b/tests/providers/google/cloud/sensors/test_cloud_storage_transfer_service.py
index aa169fdf1dd23..d8027051896aa 100644
--- a/tests/providers/google/cloud/sensors/test_cloud_storage_transfer_service.py
+++ b/tests/providers/google/cloud/sensors/test_cloud_storage_transfer_service.py
@@ -65,7 +65,7 @@ def test_wait_for_status_success(self, mock_tool):
mock_tool.operations_contain_expected_statuses.assert_called_once_with(
operations=operations, expected_statuses={GcpTransferOperationStatus.SUCCESS}
)
- self.assertTrue(result)
+ assert result
@mock.patch(
'airflow.providers.google.cloud.sensors.cloud_storage_transfer_service.CloudDataTransferServiceHook'
@@ -86,7 +86,7 @@ def test_wait_for_status_success_default_expected_status(self, mock_tool):
mock_tool.operations_contain_expected_statuses.assert_called_once_with(
operations=mock.ANY, expected_statuses={GcpTransferOperationStatus.SUCCESS}
)
- self.assertTrue(result)
+ assert result
@mock.patch(
'airflow.providers.google.cloud.sensors.cloud_storage_transfer_service.CloudDataTransferServiceHook'
@@ -126,7 +126,7 @@ def test_wait_for_status_after_retry(self, mock_tool):
context = {'ti': (mock.Mock(**{'xcom_push.return_value': None}))}
result = op.poke(context)
- self.assertFalse(result)
+ assert not result
mock_tool.operations_contain_expected_statuses.assert_called_once_with(
operations=operations_set[0], expected_statuses={GcpTransferOperationStatus.SUCCESS}
@@ -134,7 +134,7 @@ def test_wait_for_status_after_retry(self, mock_tool):
mock_tool.operations_contain_expected_statuses.reset_mock()
result = op.poke(context)
- self.assertTrue(result)
+ assert result
mock_tool.operations_contain_expected_statuses.assert_called_once_with(
operations=operations_set[1], expected_statuses={GcpTransferOperationStatus.SUCCESS}
@@ -177,7 +177,7 @@ def test_wait_for_status_normalize_status(self, expected_status, received_status
context = {'ti': (mock.Mock(**{'xcom_push.return_value': None}))}
result = op.poke(context)
- self.assertFalse(result)
+ assert not result
mock_tool.operations_contain_expected_statuses.assert_called_once_with(
operations=operations, expected_statuses=received_status
diff --git a/tests/providers/google/cloud/sensors/test_dataflow.py b/tests/providers/google/cloud/sensors/test_dataflow.py
index 54561e302f6ea..9c8b158deec18 100644
--- a/tests/providers/google/cloud/sensors/test_dataflow.py
+++ b/tests/providers/google/cloud/sensors/test_dataflow.py
@@ -19,6 +19,7 @@
import unittest
from unittest import mock
+import pytest
from parameterized import parameterized
from airflow.exceptions import AirflowException
@@ -62,7 +63,7 @@ def test_poke(self, expected_status, current_status, sensor_return, mock_hook):
mock_get_job.return_value = {"id": TEST_JOB_ID, "currentState": current_status}
results = task.poke(mock.MagicMock())
- self.assertEqual(sensor_return, results)
+ assert sensor_return == results
mock_hook.assert_called_once_with(
gcp_conn_id=TEST_GCP_CONN_ID,
@@ -88,9 +89,9 @@ def test_poke_raise_exception(self, mock_hook):
)
mock_get_job.return_value = {"id": TEST_JOB_ID, "currentState": DataflowJobStatus.JOB_STATE_CANCELLED}
- with self.assertRaisesRegex(
+ with pytest.raises(
AirflowException,
- f"Job with id '{TEST_JOB_ID}' is already in terminal state: "
+ match=f"Job with id '{TEST_JOB_ID}' is already in terminal state: "
f"{DataflowJobStatus.JOB_STATE_CANCELLED}",
):
task.poke(mock.MagicMock())
@@ -133,7 +134,7 @@ def test_poke(self, job_current_state, fail_on_terminal_state, mock_hook):
mock_get_job.return_value = {"id": TEST_JOB_ID, "currentState": job_current_state}
results = task.poke(mock.MagicMock())
- self.assertEqual(callback.return_value, results)
+ assert callback.return_value == results
mock_hook.assert_called_once_with(
gcp_conn_id=TEST_GCP_CONN_ID,
@@ -176,7 +177,7 @@ def test_poke(self, job_current_state, fail_on_terminal_state, mock_hook):
results = task.poke(mock.MagicMock())
- self.assertEqual(callback.return_value, results)
+ assert callback.return_value == results
mock_hook.assert_called_once_with(
gcp_conn_id=TEST_GCP_CONN_ID,
@@ -207,9 +208,9 @@ def test_poke_raise_exception(self, mock_hook):
)
mock_get_job.return_value = {"id": TEST_JOB_ID, "currentState": DataflowJobStatus.JOB_STATE_DONE}
- with self.assertRaisesRegex(
+ with pytest.raises(
AirflowException,
- f"Job with id '{TEST_JOB_ID}' is already in terminal state: "
+ match=f"Job with id '{TEST_JOB_ID}' is already in terminal state: "
f"{DataflowJobStatus.JOB_STATE_DONE}",
):
task.poke(mock.MagicMock())
@@ -252,7 +253,7 @@ def test_poke(self, job_current_state, fail_on_terminal_state, mock_hook):
results = task.poke(mock.MagicMock())
- self.assertEqual(callback.return_value, results)
+ assert callback.return_value == results
mock_hook.assert_called_once_with(
gcp_conn_id=TEST_GCP_CONN_ID,
@@ -283,9 +284,9 @@ def test_poke_raise_exception_on_terminal_state(self, mock_hook):
)
mock_get_job.return_value = {"id": TEST_JOB_ID, "currentState": DataflowJobStatus.JOB_STATE_DONE}
- with self.assertRaisesRegex(
+ with pytest.raises(
AirflowException,
- f"Job with id '{TEST_JOB_ID}' is already in terminal state: "
+ match=f"Job with id '{TEST_JOB_ID}' is already in terminal state: "
f"{DataflowJobStatus.JOB_STATE_DONE}",
):
task.poke(mock.MagicMock())
diff --git a/tests/providers/google/cloud/sensors/test_dataproc.py b/tests/providers/google/cloud/sensors/test_dataproc.py
index f2a1d45704cc8..1ce8eea3fcb75 100644
--- a/tests/providers/google/cloud/sensors/test_dataproc.py
+++ b/tests/providers/google/cloud/sensors/test_dataproc.py
@@ -18,6 +18,7 @@
import unittest
from unittest import mock
+import pytest
from google.cloud.dataproc_v1beta2.types import JobStatus
from airflow import AirflowException
@@ -61,7 +62,7 @@ def test_done(self, mock_hook):
mock_hook.return_value.get_job.assert_called_once_with(
job_id=job_id, location=GCP_LOCATION, project_id=GCP_PROJECT
)
- self.assertTrue(ret)
+ assert ret
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_error(self, mock_hook):
@@ -78,7 +79,7 @@ def test_error(self, mock_hook):
timeout=TIMEOUT,
)
- with self.assertRaisesRegex(AirflowException, "Job failed"):
+ with pytest.raises(AirflowException, match="Job failed"):
sensor.poke(context={})
mock_hook.return_value.get_job.assert_called_once_with(
@@ -104,7 +105,7 @@ def test_wait(self, mock_hook):
mock_hook.return_value.get_job.assert_called_once_with(
job_id=job_id, location=GCP_LOCATION, project_id=GCP_PROJECT
)
- self.assertFalse(ret)
+ assert not ret
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_cancelled(self, mock_hook):
@@ -120,7 +121,7 @@ def test_cancelled(self, mock_hook):
gcp_conn_id=GCP_CONN_ID,
timeout=TIMEOUT,
)
- with self.assertRaisesRegex(AirflowException, "Job was cancelled"):
+ with pytest.raises(AirflowException, match="Job was cancelled"):
sensor.poke(context={})
mock_hook.return_value.get_job.assert_called_once_with(
diff --git a/tests/providers/google/cloud/sensors/test_gcs.py b/tests/providers/google/cloud/sensors/test_gcs.py
index 2bd9f60c0cd44..e3c89174cfbcf 100644
--- a/tests/providers/google/cloud/sensors/test_gcs.py
+++ b/tests/providers/google/cloud/sensors/test_gcs.py
@@ -19,6 +19,7 @@
from unittest import TestCase, mock
import pendulum
+import pytest
from airflow.exceptions import AirflowSensorTimeout
from airflow.models.dag import DAG, AirflowException
@@ -75,7 +76,7 @@ def test_should_pass_argument_to_hook(self, mock_hook):
result = task.poke(mock.MagicMock())
- self.assertEqual(True, result)
+ assert result is True
mock_hook.assert_called_once_with(
delegate_to=TEST_DELEGATE_TO,
google_cloud_storage_conn_id=TEST_GCP_CONN_ID,
@@ -91,7 +92,7 @@ def test_should_support_datetime(self):
'execution_date': datetime(2019, 2, 14, 0, 0),
}
result = ts_function(context)
- self.assertEqual(datetime(2019, 2, 19, 0, 0, tzinfo=timezone.utc), result)
+ assert datetime(2019, 2, 19, 0, 0, tzinfo=timezone.utc) == result
def test_should_support_cron(self):
dag = DAG(dag_id=TEST_DAG_ID, start_date=datetime(2019, 2, 19, 0, 0), schedule_interval='@weekly')
@@ -101,7 +102,7 @@ def test_should_support_cron(self):
'execution_date': datetime(2019, 2, 19),
}
result = ts_function(context)
- self.assertEqual(pendulum.instance(datetime(2019, 2, 24)).isoformat(), result.isoformat())
+ assert pendulum.instance(datetime(2019, 2, 24)).isoformat() == result.isoformat()
class TestGoogleCloudStorageObjectUpdatedSensor(TestCase):
@@ -124,7 +125,7 @@ def test_should_pass_argument_to_hook(self, mock_hook):
impersonation_chain=TEST_IMPERSONATION_CHAIN,
)
mock_hook.return_value.is_updated_after.assert_called_once_with(TEST_BUCKET, TEST_OBJECT, mock.ANY)
- self.assertEqual(True, result)
+ assert result is True
class TestGoogleCloudStoragePrefixSensor(TestCase):
@@ -147,7 +148,7 @@ def test_should_pass_arguments_to_hook(self, mock_hook):
impersonation_chain=TEST_IMPERSONATION_CHAIN,
)
mock_hook.return_value.list.assert_called_once_with(TEST_BUCKET, prefix=TEST_PREFIX)
- self.assertEqual(True, result)
+ assert result is True
@mock.patch("airflow.providers.google.cloud.sensors.gcs.GCSHook")
def test_should_return_false_on_empty_list(self, mock_hook):
@@ -161,7 +162,7 @@ def test_should_return_false_on_empty_list(self, mock_hook):
mock_hook.return_value.list.return_value = []
result = task.poke(mock.MagicMock)
- self.assertEqual(False, result)
+ assert result is False
@mock.patch('airflow.providers.google.cloud.sensors.gcs.GCSHook')
def test_execute(self, mock_hook):
@@ -185,7 +186,7 @@ def test_execute(self, mock_hook):
impersonation_chain=TEST_IMPERSONATION_CHAIN,
)
mock_hook.return_value.list.assert_called_once_with(TEST_BUCKET, prefix=TEST_PREFIX)
- self.assertEqual(response, generated_messages)
+ assert response == generated_messages
@mock.patch('airflow.providers.google.cloud.sensors.gcs.GCSHook')
def test_execute_timeout(self, mock_hook):
@@ -193,7 +194,7 @@ def test_execute_timeout(self, mock_hook):
task_id="task-id", bucket=TEST_BUCKET, prefix=TEST_PREFIX, poke_interval=0, timeout=1
)
mock_hook.return_value.list.return_value = []
- with self.assertRaises(AirflowSensorTimeout):
+ with pytest.raises(AirflowSensorTimeout):
task.execute(mock.MagicMock)
mock_hook.return_value.list.assert_called_once_with(TEST_BUCKET, prefix=TEST_PREFIX)
@@ -232,12 +233,12 @@ def test_get_gcs_hook(self, mock_hook):
delegate_to=TEST_DELEGATE_TO,
impersonation_chain=TEST_IMPERSONATION_CHAIN,
)
- self.assertEqual(mock_hook.return_value, self.sensor.hook)
+ assert mock_hook.return_value == self.sensor.hook
@mock.patch('airflow.providers.google.cloud.sensors.gcs.get_time', mock_time)
def test_files_deleted_between_pokes_throw_error(self):
self.sensor.is_bucket_updated({'a', 'b'})
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.sensor.is_bucket_updated({'a'})
@mock.patch('airflow.providers.google.cloud.sensors.gcs.get_time', mock_time)
@@ -253,46 +254,46 @@ def test_files_deleted_between_pokes_allow_delete(self):
dag=self.dag,
)
self.sensor.is_bucket_updated({'a', 'b'})
- self.assertEqual(self.sensor.inactivity_seconds, 0)
+ assert self.sensor.inactivity_seconds == 0
self.sensor.is_bucket_updated({'a'})
- self.assertEqual(len(self.sensor.previous_objects), 1)
- self.assertEqual(self.sensor.inactivity_seconds, 0)
+ assert len(self.sensor.previous_objects) == 1
+ assert self.sensor.inactivity_seconds == 0
self.sensor.is_bucket_updated({'a', 'c'})
- self.assertEqual(self.sensor.inactivity_seconds, 0)
+ assert self.sensor.inactivity_seconds == 0
self.sensor.is_bucket_updated({'a', 'd'})
- self.assertEqual(self.sensor.inactivity_seconds, 0)
+ assert self.sensor.inactivity_seconds == 0
self.sensor.is_bucket_updated({'a', 'd'})
- self.assertEqual(self.sensor.inactivity_seconds, 10)
- self.assertTrue(self.sensor.is_bucket_updated({'a', 'd'}))
+ assert self.sensor.inactivity_seconds == 10
+ assert self.sensor.is_bucket_updated({'a', 'd'})
@mock.patch('airflow.providers.google.cloud.sensors.gcs.get_time', mock_time)
def test_incoming_data(self):
self.sensor.is_bucket_updated({'a'})
- self.assertEqual(self.sensor.inactivity_seconds, 0)
+ assert self.sensor.inactivity_seconds == 0
self.sensor.is_bucket_updated({'a', 'b'})
- self.assertEqual(self.sensor.inactivity_seconds, 0)
+ assert self.sensor.inactivity_seconds == 0
self.sensor.is_bucket_updated({'a', 'b', 'c'})
- self.assertEqual(self.sensor.inactivity_seconds, 0)
+ assert self.sensor.inactivity_seconds == 0
@mock.patch('airflow.providers.google.cloud.sensors.gcs.get_time', mock_time)
def test_no_new_data(self):
self.sensor.is_bucket_updated({'a'})
- self.assertEqual(self.sensor.inactivity_seconds, 0)
+ assert self.sensor.inactivity_seconds == 0
self.sensor.is_bucket_updated({'a'})
- self.assertEqual(self.sensor.inactivity_seconds, 10)
+ assert self.sensor.inactivity_seconds == 10
@mock.patch('airflow.providers.google.cloud.sensors.gcs.get_time', mock_time)
def test_no_new_data_success_criteria(self):
self.sensor.is_bucket_updated({'a'})
- self.assertEqual(self.sensor.inactivity_seconds, 0)
+ assert self.sensor.inactivity_seconds == 0
self.sensor.is_bucket_updated({'a'})
- self.assertEqual(self.sensor.inactivity_seconds, 10)
- self.assertTrue(self.sensor.is_bucket_updated({'a'}))
+ assert self.sensor.inactivity_seconds == 10
+ assert self.sensor.is_bucket_updated({'a'})
@mock.patch('airflow.providers.google.cloud.sensors.gcs.get_time', mock_time)
def test_not_enough_objects(self):
self.sensor.is_bucket_updated(set())
- self.assertEqual(self.sensor.inactivity_seconds, 0)
+ assert self.sensor.inactivity_seconds == 0
self.sensor.is_bucket_updated(set())
- self.assertEqual(self.sensor.inactivity_seconds, 10)
- self.assertFalse(self.sensor.is_bucket_updated(set()))
+ assert self.sensor.inactivity_seconds == 10
+ assert not self.sensor.is_bucket_updated(set())
diff --git a/tests/providers/google/cloud/sensors/test_pubsub.py b/tests/providers/google/cloud/sensors/test_pubsub.py
index 3124474ef9ae8..795860b3d353c 100644
--- a/tests/providers/google/cloud/sensors/test_pubsub.py
+++ b/tests/providers/google/cloud/sensors/test_pubsub.py
@@ -20,6 +20,7 @@
from typing import Any, Dict, List
from unittest import mock
+import pytest
from google.cloud.pubsub_v1.types import ReceivedMessage
from airflow.exceptions import AirflowSensorTimeout
@@ -55,7 +56,7 @@ def test_poke_no_messages(self, mock_hook):
)
mock_hook.return_value.pull.return_value = []
- self.assertEqual(False, operator.poke({}))
+ assert operator.poke({}) is False
@mock.patch('airflow.providers.google.cloud.sensors.pubsub.PubSubHook')
def test_poke_with_ack_messages(self, mock_hook):
@@ -70,7 +71,7 @@ def test_poke_with_ack_messages(self, mock_hook):
mock_hook.return_value.pull.return_value = generated_messages
- self.assertEqual(True, operator.poke({}))
+ assert operator.poke({}) is True
mock_hook.return_value.acknowledge.assert_called_once_with(
project_id=TEST_PROJECT,
subscription=TEST_SUBSCRIPTION,
@@ -94,7 +95,7 @@ def test_execute(self, mock_hook):
mock_hook.return_value.pull.assert_called_once_with(
project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, max_messages=5, return_immediately=True
)
- self.assertEqual(generated_dicts, response)
+ assert generated_dicts == response
@mock.patch('airflow.providers.google.cloud.sensors.pubsub.PubSubHook')
def test_execute_timeout(self, mock_hook):
@@ -108,7 +109,7 @@ def test_execute_timeout(self, mock_hook):
mock_hook.return_value.pull.return_value = []
- with self.assertRaises(AirflowSensorTimeout):
+ with pytest.raises(AirflowSensorTimeout):
operator.execute({})
mock_hook.return_value.pull.assert_called_once_with(
project_id=TEST_PROJECT,
diff --git a/tests/providers/google/cloud/transfers/test_adls_to_gcs.py b/tests/providers/google/cloud/transfers/test_adls_to_gcs.py
index 68649c242c6d1..9e7657f08adac 100644
--- a/tests/providers/google/cloud/transfers/test_adls_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_adls_to_gcs.py
@@ -49,12 +49,12 @@ def test_init(self):
gcp_conn_id=GCS_CONN_ID,
)
- self.assertEqual(operator.task_id, TASK_ID)
- self.assertEqual(operator.src_adls, ADLS_PATH_1)
- self.assertEqual(operator.dest_gcs, GCS_PATH)
- self.assertEqual(operator.replace, False)
- self.assertEqual(operator.gcp_conn_id, GCS_CONN_ID)
- self.assertEqual(operator.azure_data_lake_conn_id, AZURE_CONN_ID)
+ assert operator.task_id == TASK_ID
+ assert operator.src_adls == ADLS_PATH_1
+ assert operator.dest_gcs == GCS_PATH
+ assert operator.replace is False
+ assert operator.gcp_conn_id == GCS_CONN_ID
+ assert operator.azure_data_lake_conn_id == AZURE_CONN_ID
@mock.patch('airflow.providers.google.cloud.transfers.adls_to_gcs.AzureDataLakeHook')
@mock.patch('airflow.providers.microsoft.azure.operators.adls_list.AzureDataLakeHook')
@@ -101,7 +101,7 @@ def test_execute(self, gcs_mock_hook, adls_one_mock_hook, adls_two_mock_hook):
)
# we expect MOCK_FILES to be uploaded
- self.assertEqual(sorted(MOCK_FILES), sorted(uploaded_files))
+ assert sorted(MOCK_FILES) == sorted(uploaded_files)
@mock.patch('airflow.providers.google.cloud.transfers.adls_to_gcs.AzureDataLakeHook')
@mock.patch('airflow.providers.microsoft.azure.operators.adls_list.AzureDataLakeHook')
@@ -140,4 +140,4 @@ def test_execute_with_gzip(self, gcs_mock_hook, adls_one_mock_hook, adls_two_moc
)
# we expect MOCK_FILES to be uploaded
- self.assertEqual(sorted(MOCK_FILES), sorted(uploaded_files))
+ assert sorted(MOCK_FILES) == sorted(uploaded_files)
diff --git a/tests/providers/google/cloud/transfers/test_azure_fileshare_to_gcs.py b/tests/providers/google/cloud/transfers/test_azure_fileshare_to_gcs.py
index 1a46622869b27..424baa1713b40 100644
--- a/tests/providers/google/cloud/transfers/test_azure_fileshare_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_azure_fileshare_to_gcs.py
@@ -44,13 +44,13 @@ def test_init(self):
google_impersonation_chain=IMPERSONATION_CHAIN,
)
- self.assertEqual(operator.task_id, TASK_ID)
- self.assertEqual(operator.share_name, AZURE_FILESHARE_SHARE)
- self.assertEqual(operator.directory_name, AZURE_FILESHARE_DIRECTORY_NAME)
- self.assertEqual(operator.wasb_conn_id, WASB_CONN_ID)
- self.assertEqual(operator.gcp_conn_id, GCS_CONN_ID)
- self.assertEqual(operator.dest_gcs, GCS_PATH_PREFIX)
- self.assertEqual(operator.google_impersonation_chain, IMPERSONATION_CHAIN)
+ assert operator.task_id == TASK_ID
+ assert operator.share_name == AZURE_FILESHARE_SHARE
+ assert operator.directory_name == AZURE_FILESHARE_DIRECTORY_NAME
+ assert operator.wasb_conn_id == WASB_CONN_ID
+ assert operator.gcp_conn_id == GCS_CONN_ID
+ assert operator.dest_gcs == GCS_PATH_PREFIX
+ assert operator.google_impersonation_chain == IMPERSONATION_CHAIN
@mock.patch('airflow.providers.google.cloud.transfers.azure_fileshare_to_gcs.AzureFileShareHook')
@mock.patch('airflow.providers.google.cloud.transfers.azure_fileshare_to_gcs.GCSHook')
@@ -88,7 +88,7 @@ def test_execute(self, gcs_mock_hook, azure_fileshare_mock_hook):
impersonation_chain=IMPERSONATION_CHAIN,
)
- self.assertEqual(sorted(MOCK_FILES), sorted(uploaded_files))
+ assert sorted(MOCK_FILES) == sorted(uploaded_files)
@mock.patch('airflow.providers.google.cloud.transfers.azure_fileshare_to_gcs.AzureFileShareHook')
@mock.patch('airflow.providers.google.cloud.transfers.azure_fileshare_to_gcs.GCSHook')
diff --git a/tests/providers/google/cloud/transfers/test_cassandra_to_gcs.py b/tests/providers/google/cloud/transfers/test_cassandra_to_gcs.py
index 4d6481d1b1e72..e71745d112a1e 100644
--- a/tests/providers/google/cloud/transfers/test_cassandra_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_cassandra_to_gcs.py
@@ -65,51 +65,52 @@ def test_execute(self, mock_hook, mock_upload, mock_tempfile):
def test_convert_value(self):
op = CassandraToGCSOperator
- self.assertEqual(op.convert_value(None), None)
- self.assertEqual(op.convert_value(1), 1)
- self.assertEqual(op.convert_value(1.0), 1.0)
- self.assertEqual(op.convert_value("text"), "text")
- self.assertEqual(op.convert_value(True), True)
- self.assertEqual(op.convert_value({"a": "b"}), {"a": "b"})
+ assert op.convert_value(None) is None
+ assert op.convert_value(1) == 1
+ assert op.convert_value(1.0) == 1.0
+ assert op.convert_value("text") == "text"
+ assert op.convert_value(True) is True
+ assert op.convert_value({"a": "b"}) == {"a": "b"}
from datetime import datetime
now = datetime.now()
- self.assertEqual(op.convert_value(now), str(now))
+ assert op.convert_value(now) == str(now)
from cassandra.util import Date
date_str = "2018-01-01"
date = Date(date_str)
- self.assertEqual(op.convert_value(date), str(date_str))
+ assert op.convert_value(date) == str(date_str)
import uuid
from base64 import b64encode
test_uuid = uuid.uuid4()
encoded_uuid = b64encode(test_uuid.bytes).decode("ascii")
- self.assertEqual(op.convert_value(test_uuid), encoded_uuid)
+ assert op.convert_value(test_uuid) == encoded_uuid
byte_str = b"abc"
encoded_b = b64encode(byte_str).decode("ascii")
- self.assertEqual(op.convert_value(byte_str), encoded_b)
+ assert op.convert_value(byte_str) == encoded_b
from decimal import Decimal
decimal = Decimal(1.0)
- self.assertEqual(op.convert_value(decimal), float(decimal))
+ assert op.convert_value(decimal) == float(decimal)
from cassandra.util import Time
time = Time(0)
- self.assertEqual(op.convert_value(time), "00:00:00")
+ assert op.convert_value(time) == "00:00:00"
date_str_lst = ["2018-01-01", "2018-01-02", "2018-01-03"]
date_lst = [Date(d) for d in date_str_lst]
- self.assertEqual(op.convert_value(date_lst), date_str_lst)
+ assert op.convert_value(date_lst) == date_str_lst
date_tpl = tuple(date_lst)
- self.assertEqual(
- op.convert_value(date_tpl),
- {"field_0": "2018-01-01", "field_1": "2018-01-02", "field_2": "2018-01-03"},
- )
+ assert op.convert_value(date_tpl) == {
+ "field_0": "2018-01-01",
+ "field_1": "2018-01-02",
+ "field_2": "2018-01-03",
+ }
diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py b/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py
index b1a0188681f9b..5c0c38cb6c9cd 100644
--- a/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py
@@ -20,6 +20,8 @@
from datetime import datetime
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.transfers.gcs_to_gcs import WILDCARD, GCSToGCSOperator
@@ -381,7 +383,7 @@ def test_execute_more_than_1_wildcard(self, mock_hook):
total_wildcards
)
- with self.assertRaisesRegex(AirflowException, error_msg):
+ with pytest.raises(AirflowException, match=error_msg):
operator.execute(None)
@mock.patch('airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook')
@@ -400,7 +402,7 @@ def test_execute_with_empty_destination_bucket(self, mock_hook):
mock_warn.assert_called_once_with(
'destination_bucket is None. Defaulting it to source_bucket (%s)', TEST_BUCKET
)
- self.assertEqual(operator.destination_bucket, operator.source_bucket)
+ assert operator.destination_bucket == operator.source_bucket
# Tests the use of delimiter and source object as list
@mock.patch('airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook')
@@ -419,9 +421,7 @@ def test_raises_exception_with_two_empty_list_inside_source_objects(self, mock_h
task_id=TASK_ID, source_bucket=TEST_BUCKET, source_objects=SOURCE_OBJECTS_TWO_EMPTY_STRING
)
- with self.assertRaisesRegex(
- AirflowException, "You can't have two empty strings inside source_object"
- ):
+ with pytest.raises(AirflowException, match="You can't have two empty strings inside source_object"):
operator.execute(None)
@mock.patch('airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook')
diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_sftp.py b/tests/providers/google/cloud/transfers/test_gcs_to_sftp.py
index 3cf3d537392f9..d1b0f626ab0d2 100644
--- a/tests/providers/google/cloud/transfers/test_gcs_to_sftp.py
+++ b/tests/providers/google/cloud/transfers/test_gcs_to_sftp.py
@@ -21,6 +21,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.transfers.gcs_to_sftp import GCSToSFTPOperator
@@ -69,11 +71,11 @@ def test_execute_copy_single_file(self, sftp_hook, gcs_hook):
sftp_hook.assert_called_once_with(SFTP_CONN_ID)
args, kwargs = gcs_hook.return_value.download.call_args
- self.assertEqual(kwargs["bucket_name"], TEST_BUCKET)
- self.assertEqual(kwargs["object_name"], SOURCE_OBJECT_NO_WILDCARD)
+ assert kwargs["bucket_name"] == TEST_BUCKET
+ assert kwargs["object_name"] == SOURCE_OBJECT_NO_WILDCARD
args, kwargs = sftp_hook.return_value.store_file.call_args
- self.assertEqual(args[0], os.path.join(DESTINATION_SFTP, SOURCE_OBJECT_NO_WILDCARD))
+ assert args[0] == os.path.join(DESTINATION_SFTP, SOURCE_OBJECT_NO_WILDCARD)
gcs_hook.return_value.delete.assert_not_called()
@@ -100,11 +102,11 @@ def test_execute_move_single_file(self, sftp_hook, gcs_hook):
sftp_hook.assert_called_once_with(SFTP_CONN_ID)
args, kwargs = gcs_hook.return_value.download.call_args
- self.assertEqual(kwargs["bucket_name"], TEST_BUCKET)
- self.assertEqual(kwargs["object_name"], SOURCE_OBJECT_NO_WILDCARD)
+ assert kwargs["bucket_name"] == TEST_BUCKET
+ assert kwargs["object_name"] == SOURCE_OBJECT_NO_WILDCARD
args, kwargs = sftp_hook.return_value.store_file.call_args
- self.assertEqual(args[0], os.path.join(DESTINATION_SFTP, SOURCE_OBJECT_NO_WILDCARD))
+ assert args[0] == os.path.join(DESTINATION_SFTP, SOURCE_OBJECT_NO_WILDCARD)
gcs_hook.return_value.delete.assert_called_once_with(TEST_BUCKET, SOURCE_OBJECT_NO_WILDCARD)
@@ -127,11 +129,11 @@ def test_execute_copy_with_wildcard(self, sftp_hook, gcs_hook):
gcs_hook.return_value.list.assert_called_with(TEST_BUCKET, delimiter=".txt", prefix="test_object")
call_one, call_two = gcs_hook.return_value.download.call_args_list
- self.assertEqual(call_one[1]["bucket_name"], TEST_BUCKET)
- self.assertEqual(call_one[1]["object_name"], "test_object/file1.txt")
+ assert call_one[1]["bucket_name"] == TEST_BUCKET
+ assert call_one[1]["object_name"] == "test_object/file1.txt"
- self.assertEqual(call_two[1]["bucket_name"], TEST_BUCKET)
- self.assertEqual(call_two[1]["object_name"], "test_object/file2.txt")
+ assert call_two[1]["bucket_name"] == TEST_BUCKET
+ assert call_two[1]["object_name"] == "test_object/file2.txt"
@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_sftp.GCSHook")
@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_sftp.SFTPHook")
@@ -152,8 +154,8 @@ def test_execute_move_with_wildcard(self, sftp_hook, gcs_hook):
gcs_hook.return_value.list.assert_called_with(TEST_BUCKET, delimiter=".txt", prefix="test_object")
call_one, call_two = gcs_hook.return_value.delete.call_args_list
- self.assertEqual(call_one[0], (TEST_BUCKET, "test_object/file1.txt"))
- self.assertEqual(call_two[0], (TEST_BUCKET, "test_object/file2.txt"))
+ assert call_one[0] == (TEST_BUCKET, "test_object/file1.txt")
+ assert call_two[0] == (TEST_BUCKET, "test_object/file2.txt")
@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_sftp.GCSHook")
@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_sftp.SFTPHook")
@@ -169,5 +171,5 @@ def test_execute_more_than_one_wildcard_exception(self, sftp_hook, gcs_hook):
sftp_conn_id=SFTP_CONN_ID,
delegate_to=DELEGATE_TO,
)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
operator.execute(None)
diff --git a/tests/providers/google/cloud/transfers/test_local_to_gcs.py b/tests/providers/google/cloud/transfers/test_local_to_gcs.py
index 800d9d8a23d35..e7efb905f0fba 100644
--- a/tests/providers/google/cloud/transfers/test_local_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_local_to_gcs.py
@@ -56,11 +56,11 @@ def test_init(self):
dst='test/test1.csv',
**self._config,
)
- self.assertEqual(operator.src, self.testfile1)
- self.assertEqual(operator.dst, 'test/test1.csv')
- self.assertEqual(operator.bucket, self._config['bucket'])
- self.assertEqual(operator.mime_type, self._config['mime_type'])
- self.assertEqual(operator.gzip, self._config['gzip'])
+ assert operator.src == self.testfile1
+ assert operator.dst == 'test/test1.csv'
+ assert operator.bucket == self._config['bucket']
+ assert operator.mime_type == self._config['mime_type']
+ assert operator.gzip == self._config['gzip']
@mock.patch('airflow.providers.google.cloud.transfers.local_to_gcs.GCSHook', autospec=True)
def test_execute(self, mock_hook):
diff --git a/tests/providers/google/cloud/transfers/test_mssql_to_gcs.py b/tests/providers/google/cloud/transfers/test_mssql_to_gcs.py
index 8f22ef4e26d88..d9978d813aaeb 100644
--- a/tests/providers/google/cloud/transfers/test_mssql_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_mssql_to_gcs.py
@@ -53,10 +53,10 @@ class TestMsSqlToGoogleCloudStorageOperator(unittest.TestCase):
def test_init(self):
"""Test MySqlToGoogleCloudStorageOperator instance is properly initialized."""
op = MSSQLToGCSOperator(task_id=TASK_ID, sql=SQL, bucket=BUCKET, filename=JSON_FILENAME)
- self.assertEqual(op.task_id, TASK_ID)
- self.assertEqual(op.sql, SQL)
- self.assertEqual(op.bucket, BUCKET)
- self.assertEqual(op.filename, JSON_FILENAME)
+ assert op.task_id == TASK_ID
+ assert op.sql == SQL
+ assert op.bucket == BUCKET
+ assert op.filename == JSON_FILENAME
@mock.patch('airflow.providers.google.cloud.transfers.mssql_to_gcs.MsSqlHook')
@mock.patch('airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook')
@@ -73,12 +73,12 @@ def test_exec_success_json(self, gcs_hook_mock_class, mssql_hook_mock_class):
gcs_hook_mock = gcs_hook_mock_class.return_value
def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
- self.assertEqual(BUCKET, bucket)
- self.assertEqual(JSON_FILENAME.format(0), obj)
- self.assertEqual('application/json', mime_type)
- self.assertEqual(GZIP, gzip)
+ assert BUCKET == bucket
+ assert JSON_FILENAME.format(0) == obj
+ assert 'application/json' == mime_type
+ assert GZIP == gzip
with open(tmp_filename, 'rb') as file:
- self.assertEqual(b''.join(NDJSON_LINES), file.read())
+ assert b''.join(NDJSON_LINES) == file.read()
gcs_hook_mock.upload.side_effect = _assert_upload
@@ -102,11 +102,11 @@ def test_file_splitting(self, gcs_hook_mock_class, mssql_hook_mock_class):
}
def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
- self.assertEqual(BUCKET, bucket)
- self.assertEqual('application/json', mime_type)
- self.assertEqual(GZIP, gzip)
+ assert BUCKET == bucket
+ assert 'application/json' == mime_type
+ assert GZIP == gzip
with open(tmp_filename, 'rb') as file:
- self.assertEqual(expected_upload[obj], file.read())
+ assert expected_upload[obj] == file.read()
gcs_hook_mock.upload.side_effect = _assert_upload
@@ -132,7 +132,7 @@ def test_schema_file(self, gcs_hook_mock_class, mssql_hook_mock_class):
def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): # pylint: disable=unused-argument
if obj == SCHEMA_FILENAME:
with open(tmp_filename, 'rb') as file:
- self.assertEqual(b''.join(SCHEMA_JSON), file.read())
+ assert b''.join(SCHEMA_JSON) == file.read()
gcs_hook_mock.upload.side_effect = _assert_upload
@@ -142,4 +142,4 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): # pylint: disab
op.execute(None)
# once for the file and once for the schema
- self.assertEqual(2, gcs_hook_mock.upload.call_count)
+ assert 2 == gcs_hook_mock.upload.call_count
diff --git a/tests/providers/google/cloud/transfers/test_mysql_to_gcs.py b/tests/providers/google/cloud/transfers/test_mysql_to_gcs.py
index d5690c63afa7d..1e18c201f112e 100644
--- a/tests/providers/google/cloud/transfers/test_mysql_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_mysql_to_gcs.py
@@ -21,6 +21,7 @@
import unittest
from unittest import mock
+import pytest
from _mysql_exceptions import ProgrammingError
from parameterized import parameterized
@@ -77,12 +78,12 @@ def test_init(self):
export_format='CSV',
field_delimiter='|',
)
- self.assertEqual(op.task_id, TASK_ID)
- self.assertEqual(op.sql, SQL)
- self.assertEqual(op.bucket, BUCKET)
- self.assertEqual(op.filename, JSON_FILENAME)
- self.assertEqual(op.export_format, 'csv')
- self.assertEqual(op.field_delimiter, '|')
+ assert op.task_id == TASK_ID
+ assert op.sql == SQL
+ assert op.bucket == BUCKET
+ assert op.filename == JSON_FILENAME
+ assert op.export_format == 'csv'
+ assert op.field_delimiter == '|'
@parameterized.expand(
[
@@ -100,7 +101,7 @@ def test_convert_type(self, value, schema_type, expected):
op = MySQLToGCSOperator(
task_id=TASK_ID, mysql_conn_id=MYSQL_CONN_ID, sql=SQL, bucket=BUCKET, filename=JSON_FILENAME
)
- self.assertEqual(op.convert_type(value, schema_type), expected)
+ assert op.convert_type(value, schema_type) == expected
@mock.patch('airflow.providers.google.cloud.transfers.mysql_to_gcs.MySqlHook')
@mock.patch('airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook')
@@ -117,12 +118,12 @@ def test_exec_success_json(self, gcs_hook_mock_class, mysql_hook_mock_class):
gcs_hook_mock = gcs_hook_mock_class.return_value
def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
- self.assertEqual(BUCKET, bucket)
- self.assertEqual(JSON_FILENAME.format(0), obj)
- self.assertEqual('application/json', mime_type)
- self.assertFalse(gzip)
+ assert BUCKET == bucket
+ assert JSON_FILENAME.format(0) == obj
+ assert 'application/json' == mime_type
+ assert not gzip
with open(tmp_filename, 'rb') as file:
- self.assertEqual(b''.join(NDJSON_LINES), file.read())
+ assert b''.join(NDJSON_LINES) == file.read()
gcs_hook_mock.upload.side_effect = _assert_upload
@@ -151,12 +152,12 @@ def test_exec_success_csv(self, gcs_hook_mock_class, mysql_hook_mock_class):
gcs_hook_mock = gcs_hook_mock_class.return_value
def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
- self.assertEqual(BUCKET, bucket)
- self.assertEqual(CSV_FILENAME.format(0), obj)
- self.assertEqual('text/csv', mime_type)
- self.assertFalse(gzip)
+ assert BUCKET == bucket
+ assert CSV_FILENAME.format(0) == obj
+ assert 'text/csv' == mime_type
+ assert not gzip
with open(tmp_filename, 'rb') as file:
- self.assertEqual(b''.join(CSV_LINES), file.read())
+ assert b''.join(CSV_LINES) == file.read()
gcs_hook_mock.upload.side_effect = _assert_upload
@@ -186,12 +187,12 @@ def test_exec_success_csv_ensure_utc(self, gcs_hook_mock_class, mysql_hook_mock_
gcs_hook_mock = gcs_hook_mock_class.return_value
def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
- self.assertEqual(BUCKET, bucket)
- self.assertEqual(CSV_FILENAME.format(0), obj)
- self.assertEqual('text/csv', mime_type)
- self.assertFalse(gzip)
+ assert BUCKET == bucket
+ assert CSV_FILENAME.format(0) == obj
+ assert 'text/csv' == mime_type
+ assert not gzip
with open(tmp_filename, 'rb') as file:
- self.assertEqual(b''.join(CSV_LINES), file.read())
+ assert b''.join(CSV_LINES) == file.read()
gcs_hook_mock.upload.side_effect = _assert_upload
@@ -221,12 +222,12 @@ def test_exec_success_csv_with_delimiter(self, gcs_hook_mock_class, mysql_hook_m
gcs_hook_mock = gcs_hook_mock_class.return_value
def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
- self.assertEqual(BUCKET, bucket)
- self.assertEqual(CSV_FILENAME.format(0), obj)
- self.assertEqual('text/csv', mime_type)
- self.assertFalse(gzip)
+ assert BUCKET == bucket
+ assert CSV_FILENAME.format(0) == obj
+ assert 'text/csv' == mime_type
+ assert not gzip
with open(tmp_filename, 'rb') as file:
- self.assertEqual(b''.join(CSV_LINES_PIPE_DELIMITED), file.read())
+ assert b''.join(CSV_LINES_PIPE_DELIMITED) == file.read()
gcs_hook_mock.upload.side_effect = _assert_upload
@@ -250,11 +251,11 @@ def test_file_splitting(self, gcs_hook_mock_class, mysql_hook_mock_class):
}
def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
- self.assertEqual(BUCKET, bucket)
- self.assertEqual('application/json', mime_type)
- self.assertFalse(gzip)
+ assert BUCKET == bucket
+ assert 'application/json' == mime_type
+ assert not gzip
with open(tmp_filename, 'rb') as file:
- self.assertEqual(expected_upload[obj], file.read())
+ assert expected_upload[obj] == file.read()
gcs_hook_mock.upload.side_effect = _assert_upload
@@ -279,9 +280,9 @@ def test_schema_file(self, gcs_hook_mock_class, mysql_hook_mock_class):
def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): # pylint: disable=unused-argument
if obj == SCHEMA_FILENAME:
- self.assertFalse(gzip)
+ assert not gzip
with open(tmp_filename, 'rb') as file:
- self.assertEqual(b''.join(SCHEMA_JSON), file.read())
+ assert b''.join(SCHEMA_JSON) == file.read()
gcs_hook_mock.upload.side_effect = _assert_upload
@@ -291,7 +292,7 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): # pylint: disab
op.execute(None)
# once for the file and once for the schema
- self.assertEqual(2, gcs_hook_mock.upload.call_count)
+ assert 2 == gcs_hook_mock.upload.call_count
@mock.patch('airflow.providers.google.cloud.transfers.mysql_to_gcs.MySqlHook')
@mock.patch('airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook')
@@ -305,9 +306,9 @@ def test_schema_file_with_custom_schema(self, gcs_hook_mock_class, mysql_hook_mo
def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): # pylint: disable=unused-argument
if obj == SCHEMA_FILENAME:
- self.assertFalse(gzip)
+ assert not gzip
with open(tmp_filename, 'rb') as file:
- self.assertEqual(b''.join(CUSTOM_SCHEMA_JSON), file.read())
+ assert b''.join(CUSTOM_SCHEMA_JSON) == file.read()
gcs_hook_mock.upload.side_effect = _assert_upload
@@ -322,7 +323,7 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): # pylint: disab
op.execute(None)
# once for the file and once for the schema
- self.assertEqual(2, gcs_hook_mock.upload.call_count)
+ assert 2 == gcs_hook_mock.upload.call_count
@mock.patch('airflow.providers.google.cloud.transfers.mysql_to_gcs.MySqlHook')
@mock.patch('airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook')
@@ -333,7 +334,7 @@ def test_query_with_error(self, mock_gcs_hook, mock_mysql_hook):
op = MySQLToGCSOperator(
task_id=TASK_ID, sql=SQL, bucket=BUCKET, filename=JSON_FILENAME, schema_filename=SCHEMA_FILENAME
)
- with self.assertRaises(ProgrammingError):
+ with pytest.raises(ProgrammingError):
op.query()
@mock.patch('airflow.providers.google.cloud.transfers.mysql_to_gcs.MySqlHook')
@@ -345,5 +346,5 @@ def test_execute_with_query_error(self, mock_gcs_hook, mock_mysql_hook):
op = MySQLToGCSOperator(
task_id=TASK_ID, sql=SQL, bucket=BUCKET, filename=JSON_FILENAME, schema_filename=SCHEMA_FILENAME
)
- with self.assertRaises(ProgrammingError):
+ with pytest.raises(ProgrammingError):
op.execute(None)
diff --git a/tests/providers/google/cloud/transfers/test_oracle_to_gcs.py b/tests/providers/google/cloud/transfers/test_oracle_to_gcs.py
index b91dbd64ff7b8..13b743151ed48 100644
--- a/tests/providers/google/cloud/transfers/test_oracle_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_oracle_to_gcs.py
@@ -51,10 +51,10 @@ class TestOracleToGoogleCloudStorageOperator(unittest.TestCase):
def test_init(self):
"""Test OracleToGoogleCloudStorageOperator instance is properly initialized."""
op = OracleToGCSOperator(task_id=TASK_ID, sql=SQL, bucket=BUCKET, filename=JSON_FILENAME)
- self.assertEqual(op.task_id, TASK_ID)
- self.assertEqual(op.sql, SQL)
- self.assertEqual(op.bucket, BUCKET)
- self.assertEqual(op.filename, JSON_FILENAME)
+ assert op.task_id == TASK_ID
+ assert op.sql == SQL
+ assert op.bucket == BUCKET
+ assert op.filename == JSON_FILENAME
@mock.patch('airflow.providers.google.cloud.transfers.oracle_to_gcs.OracleHook')
@mock.patch('airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook')
@@ -71,12 +71,12 @@ def test_exec_success_json(self, gcs_hook_mock_class, oracle_hook_mock_class):
gcs_hook_mock = gcs_hook_mock_class.return_value
def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
- self.assertEqual(BUCKET, bucket)
- self.assertEqual(JSON_FILENAME.format(0), obj)
- self.assertEqual('application/json', mime_type)
- self.assertEqual(GZIP, gzip)
+ assert BUCKET == bucket
+ assert JSON_FILENAME.format(0) == obj
+ assert 'application/json' == mime_type
+ assert GZIP == gzip
with open(tmp_filename, 'rb') as file:
- self.assertEqual(b''.join(NDJSON_LINES), file.read())
+ assert b''.join(NDJSON_LINES) == file.read()
gcs_hook_mock.upload.side_effect = _assert_upload
@@ -100,11 +100,11 @@ def test_file_splitting(self, gcs_hook_mock_class, oracle_hook_mock_class):
}
def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
- self.assertEqual(BUCKET, bucket)
- self.assertEqual('application/json', mime_type)
- self.assertEqual(GZIP, gzip)
+ assert BUCKET == bucket
+ assert 'application/json' == mime_type
+ assert GZIP == gzip
with open(tmp_filename, 'rb') as file:
- self.assertEqual(expected_upload[obj], file.read())
+ assert expected_upload[obj] == file.read()
gcs_hook_mock.upload.side_effect = _assert_upload
@@ -130,7 +130,7 @@ def test_schema_file(self, gcs_hook_mock_class, oracle_hook_mock_class):
def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): # pylint: disable=unused-argument
if obj == SCHEMA_FILENAME:
with open(tmp_filename, 'rb') as file:
- self.assertEqual(b''.join(SCHEMA_JSON), file.read())
+ assert b''.join(SCHEMA_JSON) == file.read()
gcs_hook_mock.upload.side_effect = _assert_upload
@@ -140,4 +140,4 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): # pylint: disab
op.execute(None)
# once for the file and once for the schema
- self.assertEqual(2, gcs_hook_mock.upload.call_count)
+ assert 2 == gcs_hook_mock.upload.call_count
diff --git a/tests/providers/google/cloud/transfers/test_postgres_to_gcs.py b/tests/providers/google/cloud/transfers/test_postgres_to_gcs.py
index 0b10e7ad8ea0f..a8dfca2b86480 100644
--- a/tests/providers/google/cloud/transfers/test_postgres_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_postgres_to_gcs.py
@@ -76,18 +76,18 @@ def tearDownClass(cls):
def test_init(self):
"""Test PostgresToGoogleCloudStorageOperator instance is properly initialized."""
op = PostgresToGCSOperator(task_id=TASK_ID, sql=SQL, bucket=BUCKET, filename=FILENAME)
- self.assertEqual(op.task_id, TASK_ID)
- self.assertEqual(op.sql, SQL)
- self.assertEqual(op.bucket, BUCKET)
- self.assertEqual(op.filename, FILENAME)
+ assert op.task_id == TASK_ID
+ assert op.sql == SQL
+ assert op.bucket == BUCKET
+ assert op.filename == FILENAME
def _assert_uploaded_file_content(self, bucket, obj, tmp_filename, mime_type, gzip):
- self.assertEqual(BUCKET, bucket)
- self.assertEqual(FILENAME.format(0), obj)
- self.assertEqual('application/json', mime_type)
- self.assertFalse(gzip)
+ assert BUCKET == bucket
+ assert FILENAME.format(0) == obj
+ assert 'application/json' == mime_type
+ assert not gzip
with open(tmp_filename, 'rb') as file:
- self.assertEqual(b''.join(NDJSON_LINES), file.read())
+ assert b''.join(NDJSON_LINES) == file.read()
@patch('airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook')
def test_exec_success(self, gcs_hook_mock_class):
@@ -127,11 +127,11 @@ def test_file_splitting(self, gcs_hook_mock_class):
}
def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
- self.assertEqual(BUCKET, bucket)
- self.assertEqual('application/json', mime_type)
- self.assertFalse(gzip)
+ assert BUCKET == bucket
+ assert 'application/json' == mime_type
+ assert not gzip
with open(tmp_filename, 'rb') as file:
- self.assertEqual(expected_upload[obj], file.read())
+ assert expected_upload[obj] == file.read()
gcs_hook_mock.upload.side_effect = _assert_upload
@@ -153,7 +153,7 @@ def test_schema_file(self, gcs_hook_mock_class):
def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): # pylint: disable=unused-argument
if obj == SCHEMA_FILENAME:
with open(tmp_filename, 'rb') as file:
- self.assertEqual(SCHEMA_JSON, file.read())
+ assert SCHEMA_JSON == file.read()
gcs_hook_mock.upload.side_effect = _assert_upload
@@ -163,4 +163,4 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): # pylint: disab
op.execute(None)
# once for the file and once for the schema
- self.assertEqual(2, gcs_hook_mock.upload.call_count)
+ assert 2 == gcs_hook_mock.upload.call_count
diff --git a/tests/providers/google/cloud/transfers/test_presto_to_gcs.py b/tests/providers/google/cloud/transfers/test_presto_to_gcs.py
index 8f1c907e19a29..3eb9f63025b02 100644
--- a/tests/providers/google/cloud/transfers/test_presto_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_presto_to_gcs.py
@@ -56,22 +56,22 @@ def test_init(self):
filename=FILENAME,
impersonation_chain=IMPERSONATION_CHAIN,
)
- self.assertEqual(op.task_id, TASK_ID)
- self.assertEqual(op.sql, SQL)
- self.assertEqual(op.bucket, BUCKET)
- self.assertEqual(op.filename, FILENAME)
- self.assertEqual(op.impersonation_chain, IMPERSONATION_CHAIN)
+ assert op.task_id == TASK_ID
+ assert op.sql == SQL
+ assert op.bucket == BUCKET
+ assert op.filename == FILENAME
+ assert op.impersonation_chain == IMPERSONATION_CHAIN
@patch("airflow.providers.google.cloud.transfers.presto_to_gcs.PrestoHook")
@patch("airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook")
def test_save_as_json(self, mock_gcs_hook, mock_presto_hook):
def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
- self.assertEqual(BUCKET, bucket)
- self.assertEqual(FILENAME.format(0), obj)
- self.assertEqual("application/json", mime_type)
- self.assertFalse(gzip)
+ assert BUCKET == bucket
+ assert FILENAME.format(0) == obj
+ assert "application/json" == mime_type
+ assert not gzip
with open(tmp_filename, "rb") as file:
- self.assertEqual(b"".join(NDJSON_LINES), file.read())
+ assert b"".join(NDJSON_LINES) == file.read()
mock_gcs_hook.return_value.upload.side_effect = _assert_upload
@@ -121,11 +121,11 @@ def test_save_as_json_with_file_splitting(self, mock_gcs_hook, mock_presto_hook)
}
def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
- self.assertEqual(BUCKET, bucket)
- self.assertEqual("application/json", mime_type)
- self.assertFalse(gzip)
+ assert BUCKET == bucket
+ assert "application/json" == mime_type
+ assert not gzip
with open(tmp_filename, "rb") as file:
- self.assertEqual(expected_upload[obj], file.read())
+ assert expected_upload[obj] == file.read()
mock_gcs_hook.return_value.upload.side_effect = _assert_upload
@@ -163,7 +163,7 @@ def test_save_as_json_with_schema_file(self, mock_gcs_hook, mock_presto_hook):
def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): # pylint: disable=unused-argument
if obj == SCHEMA_FILENAME:
with open(tmp_filename, "rb") as file:
- self.assertEqual(SCHEMA_JSON, file.read())
+ assert SCHEMA_JSON == file.read()
mock_gcs_hook.return_value.upload.side_effect = _assert_upload
@@ -194,18 +194,18 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): # pylint: disab
op.execute(None)
# once for the file and once for the schema
- self.assertEqual(2, mock_gcs_hook.return_value.upload.call_count)
+ assert 2 == mock_gcs_hook.return_value.upload.call_count
@patch("airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook")
@patch("airflow.providers.google.cloud.transfers.presto_to_gcs.PrestoHook")
def test_save_as_csv(self, mock_presto_hook, mock_gcs_hook):
def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
- self.assertEqual(BUCKET, bucket)
- self.assertEqual(FILENAME.format(0), obj)
- self.assertEqual("text/csv", mime_type)
- self.assertFalse(gzip)
+ assert BUCKET == bucket
+ assert FILENAME.format(0) == obj
+ assert "text/csv" == mime_type
+ assert not gzip
with open(tmp_filename, "rb") as file:
- self.assertEqual(b"".join(CSV_LINES), file.read())
+ assert b"".join(CSV_LINES) == file.read()
mock_gcs_hook.return_value.upload.side_effect = _assert_upload
@@ -256,11 +256,11 @@ def test_save_as_csv_with_file_splitting(self, mock_gcs_hook, mock_presto_hook):
}
def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
- self.assertEqual(BUCKET, bucket)
- self.assertEqual("text/csv", mime_type)
- self.assertFalse(gzip)
+ assert BUCKET == bucket
+ assert "text/csv" == mime_type
+ assert not gzip
with open(tmp_filename, "rb") as file:
- self.assertEqual(expected_upload[obj], file.read())
+ assert expected_upload[obj] == file.read()
mock_gcs_hook.return_value.upload.side_effect = _assert_upload
@@ -299,7 +299,7 @@ def test_save_as_csv_with_schema_file(self, mock_gcs_hook, mock_presto_hook):
def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): # pylint: disable=unused-argument
if obj == SCHEMA_FILENAME:
with open(tmp_filename, "rb") as file:
- self.assertEqual(SCHEMA_JSON, file.read())
+ assert SCHEMA_JSON == file.read()
mock_gcs_hook.return_value.upload.side_effect = _assert_upload
@@ -328,4 +328,4 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): # pylint: disab
op.execute(None)
# once for the file and once for the schema
- self.assertEqual(2, mock_gcs_hook.return_value.upload.call_count)
+ assert 2 == mock_gcs_hook.return_value.upload.call_count
diff --git a/tests/providers/google/cloud/transfers/test_s3_to_gcs.py b/tests/providers/google/cloud/transfers/test_s3_to_gcs.py
index 990702fdbdb9a..f7ea825027e1d 100644
--- a/tests/providers/google/cloud/transfers/test_s3_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_s3_to_gcs.py
@@ -46,13 +46,13 @@ def test_init(self):
google_impersonation_chain=IMPERSONATION_CHAIN,
)
- self.assertEqual(operator.task_id, TASK_ID)
- self.assertEqual(operator.bucket, S3_BUCKET)
- self.assertEqual(operator.prefix, S3_PREFIX)
- self.assertEqual(operator.delimiter, S3_DELIMITER)
- self.assertEqual(operator.gcp_conn_id, GCS_CONN_ID)
- self.assertEqual(operator.dest_gcs, GCS_PATH_PREFIX)
- self.assertEqual(operator.google_impersonation_chain, IMPERSONATION_CHAIN)
+ assert operator.task_id == TASK_ID
+ assert operator.bucket == S3_BUCKET
+ assert operator.prefix == S3_PREFIX
+ assert operator.delimiter == S3_DELIMITER
+ assert operator.gcp_conn_id == GCS_CONN_ID
+ assert operator.dest_gcs == GCS_PATH_PREFIX
+ assert operator.google_impersonation_chain == IMPERSONATION_CHAIN
@mock.patch('airflow.providers.google.cloud.transfers.s3_to_gcs.S3Hook')
@mock.patch('airflow.providers.amazon.aws.operators.s3_list.S3Hook')
@@ -92,7 +92,7 @@ def test_execute(self, gcs_mock_hook, s3_one_mock_hook, s3_two_mock_hook):
)
# we expect MOCK_FILES to be uploaded
- self.assertEqual(sorted(MOCK_FILES), sorted(uploaded_files))
+ assert sorted(MOCK_FILES) == sorted(uploaded_files)
@mock.patch('airflow.providers.google.cloud.transfers.s3_to_gcs.S3Hook')
@mock.patch('airflow.providers.amazon.aws.operators.s3_list.S3Hook')
diff --git a/tests/providers/google/cloud/transfers/test_salesforce_to_gcs.py b/tests/providers/google/cloud/transfers/test_salesforce_to_gcs.py
index 5aed4b7535662..cc4ed4ad7497f 100644
--- a/tests/providers/google/cloud/transfers/test_salesforce_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_salesforce_to_gcs.py
@@ -90,4 +90,4 @@ def test_execute(self, mock_make_query, mock_write_object_to_file, mock_upload):
bucket_name=GCS_BUCKET, object_name=GCS_OBJECT_PATH, filename=mock.ANY, gzip=False
)
- self.assertEqual(EXPECTED_GCS_URI, result)
+ assert EXPECTED_GCS_URI == result
diff --git a/tests/providers/google/cloud/transfers/test_sftp_to_gcs.py b/tests/providers/google/cloud/transfers/test_sftp_to_gcs.py
index 03b21d48c085e..4592f030090e2 100644
--- a/tests/providers/google/cloud/transfers/test_sftp_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_sftp_to_gcs.py
@@ -21,6 +21,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.transfers.sftp_to_gcs import SFTPToGCSOperator
@@ -213,8 +215,8 @@ def test_execute_more_than_one_wildcard_exception(self, sftp_hook, gcs_hook):
sftp_conn_id=SFTP_CONN_ID,
delegate_to=DELEGATE_TO,
)
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
task.execute(None)
- err = cm.exception
- self.assertIn("Only one wildcard '*' is allowed in source_path parameter", str(err))
+ err = ctx.value
+ assert "Only one wildcard '*' is allowed in source_path parameter" in str(err)
diff --git a/tests/providers/google/cloud/utils/test_credentials_provider.py b/tests/providers/google/cloud/utils/test_credentials_provider.py
index c49872c0c29a9..6841306ddc889 100644
--- a/tests/providers/google/cloud/utils/test_credentials_provider.py
+++ b/tests/providers/google/cloud/utils/test_credentials_provider.py
@@ -23,6 +23,7 @@
from unittest import mock
from uuid import uuid4
+import pytest
from google.auth.environment_vars import CREDENTIALS
from parameterized import parameterized
@@ -54,20 +55,17 @@ class TestHelper(unittest.TestCase):
def test_build_gcp_conn_path(self):
value = "test"
conn = build_gcp_conn(key_file_path=value)
- self.assertEqual("google-cloud-platform://?extra__google_cloud_platform__key_path=test", conn)
+ assert "google-cloud-platform://?extra__google_cloud_platform__key_path=test" == conn
def test_build_gcp_conn_scopes(self):
value = ["test", "test2"]
conn = build_gcp_conn(scopes=value)
- self.assertEqual(
- "google-cloud-platform://?extra__google_cloud_platform__scope=test%2Ctest2",
- conn,
- )
+ assert "google-cloud-platform://?extra__google_cloud_platform__scope=test%2Ctest2" == conn
def test_build_gcp_conn_project(self):
value = "test"
conn = build_gcp_conn(project_id=value)
- self.assertEqual("google-cloud-platform://?extra__google_cloud_platform__projects=test", conn)
+ assert "google-cloud-platform://?extra__google_cloud_platform__projects=test" == conn
class TestProvideGcpCredentials(unittest.TestCase):
@@ -83,16 +81,16 @@ def test_provide_gcp_credentials_key_content(self, mock_file):
mock_file_handler.write = string_file.write
with provide_gcp_credentials(key_file_dict=file_dict):
- self.assertEqual(os.environ[CREDENTIALS], file_name)
- self.assertEqual(file_content, string_file.getvalue())
- self.assertEqual(os.environ[CREDENTIALS], ENV_VALUE)
+ assert os.environ[CREDENTIALS] == file_name
+ assert file_content == string_file.getvalue()
+ assert os.environ[CREDENTIALS] == ENV_VALUE
@mock.patch.dict(os.environ, {CREDENTIALS: ENV_VALUE})
def test_provide_gcp_credentials_keep_environment(self):
key_path = "/test/key-path"
with provide_gcp_credentials(key_file_path=key_path):
- self.assertEqual(os.environ[CREDENTIALS], key_path)
- self.assertEqual(os.environ[CREDENTIALS], ENV_VALUE)
+ assert os.environ[CREDENTIALS] == key_path
+ assert os.environ[CREDENTIALS] == ENV_VALUE
class TestProvideGcpConnection(unittest.TestCase):
@@ -105,8 +103,8 @@ def test_provide_gcp_connection(self, mock_builder):
project_id = "project_id"
with provide_gcp_connection(path, scopes, project_id):
mock_builder.assert_called_once_with(key_file_path=path, scopes=scopes, project_id=project_id)
- self.assertEqual(os.environ[AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT], TEMP_VARIABLE)
- self.assertEqual(os.environ[AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT], ENV_VALUE)
+ assert os.environ[AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT] == TEMP_VARIABLE
+ assert os.environ[AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT] == ENV_VALUE
class TestProvideGcpConnAndCredentials(unittest.TestCase):
@@ -122,10 +120,10 @@ def test_provide_gcp_conn_and_credentials(self, mock_builder):
project_id = "project_id"
with provide_gcp_conn_and_credentials(path, scopes, project_id):
mock_builder.assert_called_once_with(key_file_path=path, scopes=scopes, project_id=project_id)
- self.assertEqual(os.environ[AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT], TEMP_VARIABLE)
- self.assertEqual(os.environ[CREDENTIALS], path)
- self.assertEqual(os.environ[AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT], ENV_VALUE)
- self.assertEqual(os.environ[CREDENTIALS], ENV_VALUE)
+ assert os.environ[AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT] == TEMP_VARIABLE
+ assert os.environ[CREDENTIALS] == path
+ assert os.environ[AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT] == ENV_VALUE
+ assert os.environ[CREDENTIALS] == ENV_VALUE
class TestGetGcpCredentialsAndProjectId(unittest.TestCase):
@@ -140,15 +138,12 @@ def test_get_credentials_and_project_id_with_default_auth(self, mock_auth_defaul
with self.assertLogs() as cm:
result = get_credentials_and_project_id()
mock_auth_default.assert_called_once_with(scopes=None)
- self.assertEqual(("CREDENTIALS", "PROJECT_ID"), result)
- self.assertEqual(
- [
- 'INFO:airflow.providers.google.cloud.utils.credentials_provider._CredentialProvider:Getting '
- 'connection using `google.auth.default()` since no key file is defined for '
- 'hook.'
- ],
- cm.output,
- )
+ assert ("CREDENTIALS", "PROJECT_ID") == result
+ assert [
+ 'INFO:airflow.providers.google.cloud.utils.credentials_provider._CredentialProvider:Getting '
+ 'connection using `google.auth.default()` since no key file is defined for '
+ 'hook.'
+ ] == cm.output
@mock.patch('google.auth.default')
def test_get_credentials_and_project_id_with_default_auth_and_delegate(self, mock_auth_default):
@@ -158,7 +153,7 @@ def test_get_credentials_and_project_id_with_default_auth_and_delegate(self, moc
result = get_credentials_and_project_id(delegate_to="USER")
mock_auth_default.assert_called_once_with(scopes=None)
mock_credentials.with_subject.assert_called_once_with("USER")
- self.assertEqual((mock_credentials.with_subject.return_value, self.test_project_id), result)
+ assert (mock_credentials.with_subject.return_value, self.test_project_id) == result
@parameterized.expand([(['scope1'],), (['scope1', 'scope2'],)])
@mock.patch('google.auth.default')
@@ -168,7 +163,7 @@ def test_get_credentials_and_project_id_with_default_auth_and_scopes(self, scope
result = get_credentials_and_project_id(scopes=scopes)
mock_auth_default.assert_called_once_with(scopes=scopes)
- self.assertEqual(mock_auth_default.return_value, result)
+ assert mock_auth_default.return_value == result
@mock.patch(
'airflow.providers.google.cloud.utils.credentials_provider.' 'impersonated_credentials.Credentials'
@@ -190,7 +185,7 @@ def test_get_credentials_and_project_id_with_default_auth_and_target_principal(
delegates=None,
target_scopes=None,
)
- self.assertEqual((mock_impersonated_credentials.return_value, ANOTHER_PROJECT_ID), result)
+ assert (mock_impersonated_credentials.return_value, ANOTHER_PROJECT_ID) == result
@mock.patch(
'airflow.providers.google.cloud.utils.credentials_provider.' 'impersonated_credentials.Credentials'
@@ -213,7 +208,7 @@ def test_get_credentials_and_project_id_with_default_auth_and_scopes_and_target_
delegates=None,
target_scopes=['scope1', 'scope2'],
)
- self.assertEqual((mock_impersonated_credentials.return_value, self.test_project_id), result)
+ assert (mock_impersonated_credentials.return_value, self.test_project_id) == result
@mock.patch(
'airflow.providers.google.cloud.utils.credentials_provider.' 'impersonated_credentials.Credentials'
@@ -236,7 +231,7 @@ def test_get_credentials_and_project_id_with_default_auth_and_target_principal_a
delegates=[ACCOUNT_1_SAME_PROJECT, ACCOUNT_2_SAME_PROJECT],
target_scopes=None,
)
- self.assertEqual((mock_impersonated_credentials.return_value, ANOTHER_PROJECT_ID), result)
+ assert (mock_impersonated_credentials.return_value, ANOTHER_PROJECT_ID) == result
@mock.patch(
'google.oauth2.service_account.Credentials.from_service_account_file',
@@ -246,18 +241,15 @@ def test_get_credentials_and_project_id_with_service_account_file(self, mock_fro
with self.assertLogs(level="DEBUG") as cm:
result = get_credentials_and_project_id(key_path=self.test_key_file)
mock_from_service_account_file.assert_called_once_with(self.test_key_file, scopes=None)
- self.assertEqual((mock_from_service_account_file.return_value, self.test_project_id), result)
- self.assertEqual(
- [
- 'DEBUG:airflow.providers.google.cloud.utils.credentials_provider._CredentialProvider:Getting '
- 'connection using JSON key file KEY_PATH.json'
- ],
- cm.output,
- )
+ assert (mock_from_service_account_file.return_value, self.test_project_id) == result
+ assert [
+ 'DEBUG:airflow.providers.google.cloud.utils.credentials_provider._CredentialProvider:Getting '
+ 'connection using JSON key file KEY_PATH.json'
+ ] == cm.output
@parameterized.expand([("p12", "path/to/file.p12"), ("unknown", "incorrect_file.ext")])
def test_get_credentials_and_project_id_with_service_account_file_and_non_valid_key(self, _, file):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
get_credentials_and_project_id(key_path=file)
@mock.patch(
@@ -269,20 +261,18 @@ def test_get_credentials_and_project_id_with_service_account_info(self, mock_fro
with self.assertLogs(level="DEBUG") as cm:
result = get_credentials_and_project_id(keyfile_dict=service_account)
mock_from_service_account_info.assert_called_once_with(service_account, scopes=None)
- self.assertEqual((mock_from_service_account_info.return_value, self.test_project_id), result)
- self.assertEqual(
- [
- 'DEBUG:airflow.providers.google.cloud.utils.credentials_provider._CredentialProvider:Getting '
- 'connection using JSON Dict'
- ],
- cm.output,
- )
+ assert (mock_from_service_account_info.return_value, self.test_project_id) == result
+ assert [
+ 'DEBUG:airflow.providers.google.cloud.utils.credentials_provider._CredentialProvider:Getting '
+ 'connection using JSON Dict'
+ ] == cm.output
def test_get_credentials_and_project_id_with_mutually_exclusive_configuration(
self,
):
- with self.assertRaisesRegex(
- AirflowException, re.escape('The `keyfile_dict` and `key_path` fields are mutually exclusive.')
+ with pytest.raises(
+ AirflowException,
+ match=re.escape('The `keyfile_dict` and `key_path` fields are mutually exclusive.'),
):
get_credentials_and_project_id(key_path='KEY.json', keyfile_dict={'private_key': 'PRIVATE_KEY'})
@@ -295,18 +285,18 @@ def test_get_credentials_and_project_id_with_mutually_exclusive_configuration(
)
def test_disable_logging(self, mock_default, mock_info, mock_file):
# assert not logs
- with self.assertRaises(AssertionError), self.assertLogs(level="DEBUG"):
+ with pytest.raises(AssertionError), self.assertLogs(level="DEBUG"):
get_credentials_and_project_id(disable_logging=True)
# assert not logs
- with self.assertRaises(AssertionError), self.assertLogs(level="DEBUG"):
+ with pytest.raises(AssertionError), self.assertLogs(level="DEBUG"):
get_credentials_and_project_id(
keyfile_dict={'private_key': 'PRIVATE_KEY'},
disable_logging=True,
)
# assert not logs
- with self.assertRaises(AssertionError), self.assertLogs(level="DEBUG"):
+ with pytest.raises(AssertionError), self.assertLogs(level="DEBUG"):
get_credentials_and_project_id(
key_path='KEY.json',
disable_logging=True,
@@ -315,7 +305,7 @@ def test_disable_logging(self, mock_default, mock_info, mock_file):
class TestGetScopes(unittest.TestCase):
def test_get_scopes_with_default(self):
- self.assertEqual(_get_scopes(), _DEFAULT_SCOPES)
+ assert _get_scopes() == _DEFAULT_SCOPES
@parameterized.expand(
[
@@ -324,12 +314,12 @@ def test_get_scopes_with_default(self):
]
)
def test_get_scopes_with_input(self, _, scopes_str, scopes):
- self.assertEqual(_get_scopes(scopes_str), scopes)
+ assert _get_scopes(scopes_str) == scopes
class TestGetTargetPrincipalAndDelegates(unittest.TestCase):
def test_get_target_principal_and_delegates_no_argument(self):
- self.assertEqual(_get_target_principal_and_delegates(), (None, None))
+ assert _get_target_principal_and_delegates() == (None, None)
@parameterized.expand(
[
@@ -346,20 +336,15 @@ def test_get_target_principal_and_delegates_no_argument(self):
def test_get_target_principal_and_delegates_with_input(
self, _, impersonation_chain, target_principal_and_delegates
):
- self.assertEqual(
- _get_target_principal_and_delegates(impersonation_chain), target_principal_and_delegates
- )
+ assert _get_target_principal_and_delegates(impersonation_chain) == target_principal_and_delegates
class TestGetProjectIdFromServiceAccountEmail(unittest.TestCase):
def test_get_project_id_from_service_account_email(
self,
):
- self.assertEqual(
- _get_project_id_from_service_account_email(ACCOUNT_3_ANOTHER_PROJECT),
- ANOTHER_PROJECT_ID,
- )
+ assert _get_project_id_from_service_account_email(ACCOUNT_3_ANOTHER_PROJECT) == ANOTHER_PROJECT_ID
def test_get_project_id_from_service_account_email_wrong_input(self):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
_get_project_id_from_service_account_email("ACCOUNT_1")
diff --git a/tests/providers/google/cloud/utils/test_field_sanitizer.py b/tests/providers/google/cloud/utils/test_field_sanitizer.py
index f0de46ef7c1d4..c32e0ad89b7f0 100644
--- a/tests/providers/google/cloud/utils/test_field_sanitizer.py
+++ b/tests/providers/google/cloud/utils/test_field_sanitizer.py
@@ -18,6 +18,8 @@
import unittest
from copy import deepcopy
+import pytest
+
from airflow.providers.google.cloud.utils.field_sanitizer import GcpBodyFieldSanitizer
@@ -29,7 +31,7 @@ def test_sanitize_should_sanitize_empty_body_and_fields(self):
sanitizer = GcpBodyFieldSanitizer(fields_to_sanitize)
sanitizer.sanitize(body)
- self.assertEqual({}, body)
+ assert {} == body
def test_sanitize_should_not_fail_with_none_body(self):
body = None
@@ -38,7 +40,7 @@ def test_sanitize_should_not_fail_with_none_body(self):
sanitizer = GcpBodyFieldSanitizer(fields_to_sanitize)
sanitizer.sanitize(body)
- self.assertIsNone(body)
+ assert body is None
def test_sanitize_should_fail_with_none_fields(self):
body = {}
@@ -46,7 +48,7 @@ def test_sanitize_should_fail_with_none_fields(self):
sanitizer = GcpBodyFieldSanitizer(fields_to_sanitize)
- with self.assertRaises(TypeError):
+ with pytest.raises(TypeError):
sanitizer.sanitize(body)
def test_sanitize_should_not_fail_if_field_is_absent_in_body(self):
@@ -56,7 +58,7 @@ def test_sanitize_should_not_fail_if_field_is_absent_in_body(self):
sanitizer = GcpBodyFieldSanitizer(fields_to_sanitize)
sanitizer.sanitize(body)
- self.assertEqual({}, body)
+ assert {} == body
def test_sanitize_should_not_remove_fields_for_incorrect_specification(self):
actual_body = [
@@ -70,7 +72,7 @@ def test_sanitize_should_not_remove_fields_for_incorrect_specification(self):
sanitizer = GcpBodyFieldSanitizer(fields_to_sanitize)
sanitizer.sanitize(body)
- self.assertEqual(actual_body, body)
+ assert actual_body == body
def test_sanitize_should_remove_all_fields_from_root_level(self):
body = {"kind": "compute#instanceTemplate", "name": "instance"}
@@ -79,7 +81,7 @@ def test_sanitize_should_remove_all_fields_from_root_level(self):
sanitizer = GcpBodyFieldSanitizer(fields_to_sanitize)
sanitizer.sanitize(body)
- self.assertEqual({"name": "instance"}, body)
+ assert {"name": "instance"} == body
def test_sanitize_should_remove_for_multiple_fields_from_root_level(self):
body = {"kind": "compute#instanceTemplate", "name": "instance"}
@@ -88,7 +90,7 @@ def test_sanitize_should_remove_for_multiple_fields_from_root_level(self):
sanitizer = GcpBodyFieldSanitizer(fields_to_sanitize)
sanitizer.sanitize(body)
- self.assertEqual({}, body)
+ assert {} == body
def test_sanitize_should_remove_all_fields_in_a_list_value(self):
body = {
@@ -103,16 +105,13 @@ def test_sanitize_should_remove_all_fields_in_a_list_value(self):
sanitizer = GcpBodyFieldSanitizer(fields_to_sanitize)
sanitizer.sanitize(body)
- self.assertEqual(
- {
- "fields": [
- {"name": "instance"},
- {"name": "instance1"},
- {"name": "instance2"},
- ]
- },
- body,
- )
+ assert {
+ "fields": [
+ {"name": "instance"},
+ {"name": "instance1"},
+ {"name": "instance2"},
+ ]
+ } == body
def test_sanitize_should_remove_all_fields_in_any_nested_body(self):
fields_to_sanitize = [
@@ -145,19 +144,16 @@ def test_sanitize_should_remove_all_fields_in_any_nested_body(self):
sanitizer = GcpBodyFieldSanitizer(fields_to_sanitize)
sanitizer.sanitize(body)
- self.assertEqual(
- {
- "name": "instance",
- "properties": {
- "disks": [
- {"name": "a", "type": "PERSISTENT", "mode": "READ_WRITE"},
- {"name": "b", "type": "PERSISTENT", "mode": "READ_WRITE"},
- ],
- "metadata": {"fingerprint": "GDPUYxlwHe4="},
- },
+ assert {
+ "name": "instance",
+ "properties": {
+ "disks": [
+ {"name": "a", "type": "PERSISTENT", "mode": "READ_WRITE"},
+ {"name": "b", "type": "PERSISTENT", "mode": "READ_WRITE"},
+ ],
+ "metadata": {"fingerprint": "GDPUYxlwHe4="},
},
- body,
- )
+ } == body
def test_sanitize_should_not_fail_if_specification_has_none_value(self):
fields_to_sanitize = [
@@ -171,7 +167,7 @@ def test_sanitize_should_not_fail_if_specification_has_none_value(self):
sanitizer = GcpBodyFieldSanitizer(fields_to_sanitize)
sanitizer.sanitize(body)
- self.assertEqual({"name": "instance", "properties": {"disks": None}}, body)
+ assert {"name": "instance", "properties": {"disks": None}} == body
def test_sanitize_should_not_fail_if_no_specification_matches(self):
fields_to_sanitize = [
@@ -184,7 +180,7 @@ def test_sanitize_should_not_fail_if_no_specification_matches(self):
sanitizer = GcpBodyFieldSanitizer(fields_to_sanitize)
sanitizer.sanitize(body)
- self.assertEqual({"name": "instance", "properties": {"disks": None}}, body)
+ assert {"name": "instance", "properties": {"disks": None}} == body
def test_sanitize_should_not_fail_if_type_in_body_do_not_match_with_specification(self):
fields_to_sanitize = [
@@ -197,4 +193,4 @@ def test_sanitize_should_not_fail_if_type_in_body_do_not_match_with_specificatio
sanitizer = GcpBodyFieldSanitizer(fields_to_sanitize)
sanitizer.sanitize(body)
- self.assertEqual({"name": "instance", "properties": {"disks": 1}}, body)
+ assert {"name": "instance", "properties": {"disks": 1}} == body
diff --git a/tests/providers/google/cloud/utils/test_field_validator.py b/tests/providers/google/cloud/utils/test_field_validator.py
index 4c439c5dcd80a..43555d93e2067 100644
--- a/tests/providers/google/cloud/utils/test_field_validator.py
+++ b/tests/providers/google/cloud/utils/test_field_validator.py
@@ -17,6 +17,8 @@
import unittest
+import pytest
+
from airflow.providers.google.cloud.utils.field_validator import (
GcpBodyFieldValidator,
GcpFieldValidationException,
@@ -31,7 +33,7 @@ def test_validate_should_not_raise_exception_if_field_and_body_are_both_empty(se
validator = GcpBodyFieldValidator(specification, 'v1')
- self.assertIsNone(validator.validate(body))
+ assert validator.validate(body) is None
def test_validate_should_fail_if_body_is_none(self):
specification = []
@@ -39,7 +41,7 @@ def test_validate_should_fail_if_body_is_none(self):
validator = GcpBodyFieldValidator(specification, 'v1')
- with self.assertRaises(AttributeError):
+ with pytest.raises(AttributeError):
validator.validate(body)
def test_validate_should_fail_if_specification_is_none(self):
@@ -48,7 +50,7 @@ def test_validate_should_fail_if_specification_is_none(self):
validator = GcpBodyFieldValidator(specification, 'v1')
- with self.assertRaises(TypeError):
+ with pytest.raises(TypeError):
validator.validate(body)
def test_validate_should_raise_exception_name_attribute_is_missing_from_specs(self):
@@ -57,7 +59,7 @@ def test_validate_should_raise_exception_name_attribute_is_missing_from_specs(se
validator = GcpBodyFieldValidator(specification, 'v1')
- with self.assertRaises(KeyError):
+ with pytest.raises(KeyError):
validator.validate(body)
def test_validate_should_raise_exception_if_field_is_not_present(self):
@@ -66,7 +68,7 @@ def test_validate_should_raise_exception_if_field_is_not_present(self):
validator = GcpBodyFieldValidator(specification, 'v1')
- with self.assertRaises(GcpFieldValidationException):
+ with pytest.raises(GcpFieldValidationException):
validator.validate(body)
def test_validate_should_validate_a_single_field(self):
@@ -75,7 +77,7 @@ def test_validate_should_validate_a_single_field(self):
validator = GcpBodyFieldValidator(specification, 'v1')
- self.assertIsNone(validator.validate(body))
+ assert validator.validate(body) is None
def test_validate_should_fail_if_body_is_not_a_dict(self):
specification = [dict(name="name", allow_empty=False)]
@@ -83,7 +85,7 @@ def test_validate_should_fail_if_body_is_not_a_dict(self):
validator = GcpBodyFieldValidator(specification, 'v1')
- with self.assertRaises(AttributeError):
+ with pytest.raises(AttributeError):
validator.validate(body)
def test_validate_should_fail_for_set_allow_empty_when_field_is_none(self):
@@ -92,7 +94,7 @@ def test_validate_should_fail_for_set_allow_empty_when_field_is_none(self):
validator = GcpBodyFieldValidator(specification, 'v1')
- with self.assertRaises(GcpFieldValidationException):
+ with pytest.raises(GcpFieldValidationException):
validator.validate(body)
def test_validate_should_interpret_allow_empty_clause(self):
@@ -101,7 +103,7 @@ def test_validate_should_interpret_allow_empty_clause(self):
validator = GcpBodyFieldValidator(specification, 'v1')
- self.assertIsNone(validator.validate(body))
+ assert validator.validate(body) is None
def test_validate_should_raise_if_empty_clause_is_false(self):
specification = [dict(name="name", allow_empty=False)]
@@ -109,7 +111,7 @@ def test_validate_should_raise_if_empty_clause_is_false(self):
validator = GcpBodyFieldValidator(specification, 'v1')
- with self.assertRaises(GcpFieldValidationException):
+ with pytest.raises(GcpFieldValidationException):
validator.validate(body)
def test_validate_should_raise_if_version_mismatch_is_found(self):
@@ -126,7 +128,7 @@ def test_validate_should_interpret_optional_irrespective_of_allow_empty(self):
validator = GcpBodyFieldValidator(specification, 'v1')
- self.assertIsNone(validator.validate(body))
+ assert validator.validate(body) is None
def test_validate_should_interpret_optional_clause(self):
specification = [dict(name="name", allow_empty=False, optional=True)]
@@ -134,7 +136,7 @@ def test_validate_should_interpret_optional_clause(self):
validator = GcpBodyFieldValidator(specification, 'v1')
- self.assertIsNone(validator.validate(body))
+ assert validator.validate(body) is None
def test_validate_should_raise_exception_if_optional_clause_is_false_and_field_not_present(self):
specification = [dict(name="name", allow_empty=False, optional=False)]
@@ -142,7 +144,7 @@ def test_validate_should_raise_exception_if_optional_clause_is_false_and_field_n
validator = GcpBodyFieldValidator(specification, 'v1')
- with self.assertRaises(GcpFieldValidationException):
+ with pytest.raises(GcpFieldValidationException):
validator.validate(body)
def test_validate_should_interpret_dict_type(self):
@@ -151,7 +153,7 @@ def test_validate_should_interpret_dict_type(self):
validator = GcpBodyFieldValidator(specification, 'v1')
- self.assertIsNone(validator.validate(body))
+ assert validator.validate(body) is None
def test_validate_should_fail_if_value_is_not_dict_as_per_specs(self):
specification = [dict(name="labels", optional=True, type="dict")]
@@ -159,7 +161,7 @@ def test_validate_should_fail_if_value_is_not_dict_as_per_specs(self):
validator = GcpBodyFieldValidator(specification, 'v1')
- with self.assertRaises(GcpFieldValidationException):
+ with pytest.raises(GcpFieldValidationException):
validator.validate(body)
def test_validate_should_not_allow_both_type_and_allow_empty_in_a_spec(self):
@@ -168,7 +170,7 @@ def test_validate_should_not_allow_both_type_and_allow_empty_in_a_spec(self):
validator = GcpBodyFieldValidator(specification, 'v1')
- with self.assertRaises(GcpValidationSpecificationException):
+ with pytest.raises(GcpValidationSpecificationException):
validator.validate(body)
def test_validate_should_allow_type_and_optional_in_a_spec(self):
@@ -177,7 +179,7 @@ def test_validate_should_allow_type_and_optional_in_a_spec(self):
validator = GcpBodyFieldValidator(specification, 'v1')
- self.assertIsNone(validator.validate(body))
+ assert validator.validate(body) is None
def test_validate_should_fail_if_union_field_is_not_found(self):
specification = [
@@ -193,7 +195,7 @@ def test_validate_should_fail_if_union_field_is_not_found(self):
body = {}
validator = GcpBodyFieldValidator(specification, 'v1')
- self.assertIsNone(validator.validate(body))
+ assert validator.validate(body) is None
def test_validate_should_fail_if_there_is_no_nested_field_for_union(self):
specification = [dict(name="an_union", type="union", optional=False, fields=[])]
@@ -201,7 +203,7 @@ def test_validate_should_fail_if_there_is_no_nested_field_for_union(self):
validator = GcpBodyFieldValidator(specification, 'v1')
- with self.assertRaises(GcpValidationSpecificationException):
+ with pytest.raises(GcpValidationSpecificationException):
validator.validate(body)
def test_validate_should_interpret_union_with_one_field(self):
@@ -217,7 +219,7 @@ def test_validate_should_interpret_union_with_one_field(self):
body = {"variant_1": "abc", "variant_2": "def"}
validator = GcpBodyFieldValidator(specification, 'v1')
- self.assertIsNone(validator.validate(body))
+ assert validator.validate(body) is None
def test_validate_should_fail_if_both_field_of_union_is_present(self):
specification = [
@@ -233,7 +235,7 @@ def test_validate_should_fail_if_both_field_of_union_is_present(self):
body = {"variant_1": "abc", "variant_2": "def"}
validator = GcpBodyFieldValidator(specification, 'v1')
- with self.assertRaises(GcpFieldValidationException):
+ with pytest.raises(GcpFieldValidationException):
validator.validate(body)
def test_validate_should_validate_when_value_matches_regex(self):
@@ -249,7 +251,7 @@ def test_validate_should_validate_when_value_matches_regex(self):
body = {"variant_1": "12"}
validator = GcpBodyFieldValidator(specification, 'v1')
- self.assertIsNone(validator.validate(body))
+ assert validator.validate(body) is None
def test_validate_should_fail_when_value_does_not_match_regex(self):
specification = [
@@ -264,7 +266,7 @@ def test_validate_should_fail_when_value_does_not_match_regex(self):
body = {"variant_1": "abc"}
validator = GcpBodyFieldValidator(specification, 'v1')
- with self.assertRaises(GcpFieldValidationException):
+ with pytest.raises(GcpFieldValidationException):
validator.validate(body)
def test_validate_should_raise_if_custom_validation_is_not_true(self):
@@ -276,7 +278,7 @@ def _int_equal_to_zero(value):
body = {"availableMemoryMb": 1}
validator = GcpBodyFieldValidator(specification, 'v1')
- with self.assertRaises(GcpFieldValidationException):
+ with pytest.raises(GcpFieldValidationException):
validator.validate(body)
def test_validate_should_not_raise_if_custom_validation_is_true(self):
@@ -288,7 +290,7 @@ def _int_equal_to_zero(value):
body = {"availableMemoryMb": 0}
validator = GcpBodyFieldValidator(specification, 'v1')
- self.assertIsNone(validator.validate(body))
+ assert validator.validate(body) is None
def test_validate_should_validate_group_of_specs(self):
specification = [
diff --git a/tests/providers/google/cloud/utils/test_mlengine_operator_utils.py b/tests/providers/google/cloud/utils/test_mlengine_operator_utils.py
index 78a41bb323bd7..289eb439b639a 100644
--- a/tests/providers/google/cloud/utils/test_mlengine_operator_utils.py
+++ b/tests/providers/google/cloud/utils/test_mlengine_operator_utils.py
@@ -22,6 +22,7 @@
from unittest import mock
import dill
+import pytest
from airflow.exceptions import AirflowException
from airflow.models import DAG
@@ -113,23 +114,23 @@ def test_create_evaluate_ops(self, mock_dataflow, mock_python):
mock_dataflow.assert_called_once_with(evaluate_prediction)
mock_python.assert_called_once_with(evaluate_summary)
- self.assertEqual(TASK_PREFIX_PREDICTION, evaluate_prediction.task_id)
- self.assertEqual(PROJECT_ID, evaluate_prediction._project_id)
- self.assertEqual(BATCH_PREDICTION_JOB_ID, evaluate_prediction._job_id)
- self.assertEqual(REGION, evaluate_prediction._region)
- self.assertEqual(DATA_FORMAT, evaluate_prediction._data_format)
- self.assertEqual(INPUT_PATHS, evaluate_prediction._input_paths)
- self.assertEqual(PREDICTION_PATH, evaluate_prediction._output_path)
- self.assertEqual(MODEL_URI, evaluate_prediction._uri)
+ assert TASK_PREFIX_PREDICTION == evaluate_prediction.task_id
+ assert PROJECT_ID == evaluate_prediction._project_id
+ assert BATCH_PREDICTION_JOB_ID == evaluate_prediction._job_id
+ assert REGION == evaluate_prediction._region
+ assert DATA_FORMAT == evaluate_prediction._data_format
+ assert INPUT_PATHS == evaluate_prediction._input_paths
+ assert PREDICTION_PATH == evaluate_prediction._output_path
+ assert MODEL_URI == evaluate_prediction._uri
- self.assertEqual(TASK_PREFIX_SUMMARY, evaluate_summary.task_id)
- self.assertEqual(DATAFLOW_OPTIONS, evaluate_summary.dataflow_default_options)
- self.assertEqual(PREDICTION_PATH, evaluate_summary.options["prediction_path"])
- self.assertEqual(METRIC_FN_ENCODED, evaluate_summary.options["metric_fn_encoded"])
- self.assertEqual(METRIC_KEYS_EXPECTED, evaluate_summary.options["metric_keys"])
+ assert TASK_PREFIX_SUMMARY == evaluate_summary.task_id
+ assert DATAFLOW_OPTIONS == evaluate_summary.dataflow_default_options
+ assert PREDICTION_PATH == evaluate_summary.options["prediction_path"]
+ assert METRIC_FN_ENCODED == evaluate_summary.options["metric_fn_encoded"]
+ assert METRIC_KEYS_EXPECTED == evaluate_summary.options["metric_keys"]
- self.assertEqual(TASK_PREFIX_VALIDATION, evaluate_validation.task_id)
- self.assertEqual(PREDICTION_PATH, evaluate_validation.templates_dict["prediction_path"])
+ assert TASK_PREFIX_VALIDATION == evaluate_validation.task_id
+ assert PREDICTION_PATH == evaluate_validation.templates_dict["prediction_path"]
@mock.patch.object(PythonOperator, "set_upstream")
@mock.patch.object(DataflowCreatePythonJobOperator, "set_upstream")
@@ -154,24 +155,24 @@ def test_create_evaluate_ops_model_and_version_name(self, mock_dataflow, mock_py
mock_dataflow.assert_called_once_with(evaluate_prediction)
mock_python.assert_called_once_with(evaluate_summary)
- self.assertEqual(TASK_PREFIX_PREDICTION, evaluate_prediction.task_id)
- self.assertEqual(PROJECT_ID, evaluate_prediction._project_id)
- self.assertEqual(BATCH_PREDICTION_JOB_ID, evaluate_prediction._job_id)
- self.assertEqual(REGION, evaluate_prediction._region)
- self.assertEqual(DATA_FORMAT, evaluate_prediction._data_format)
- self.assertEqual(INPUT_PATHS, evaluate_prediction._input_paths)
- self.assertEqual(PREDICTION_PATH, evaluate_prediction._output_path)
- self.assertEqual(MODEL_NAME, evaluate_prediction._model_name)
- self.assertEqual(VERSION_NAME, evaluate_prediction._version_name)
-
- self.assertEqual(TASK_PREFIX_SUMMARY, evaluate_summary.task_id)
- self.assertEqual(DATAFLOW_OPTIONS, evaluate_summary.dataflow_default_options)
- self.assertEqual(PREDICTION_PATH, evaluate_summary.options["prediction_path"])
- self.assertEqual(METRIC_FN_ENCODED, evaluate_summary.options["metric_fn_encoded"])
- self.assertEqual(METRIC_KEYS_EXPECTED, evaluate_summary.options["metric_keys"])
-
- self.assertEqual(TASK_PREFIX_VALIDATION, evaluate_validation.task_id)
- self.assertEqual(PREDICTION_PATH, evaluate_validation.templates_dict["prediction_path"])
+ assert TASK_PREFIX_PREDICTION == evaluate_prediction.task_id
+ assert PROJECT_ID == evaluate_prediction._project_id
+ assert BATCH_PREDICTION_JOB_ID == evaluate_prediction._job_id
+ assert REGION == evaluate_prediction._region
+ assert DATA_FORMAT == evaluate_prediction._data_format
+ assert INPUT_PATHS == evaluate_prediction._input_paths
+ assert PREDICTION_PATH == evaluate_prediction._output_path
+ assert MODEL_NAME == evaluate_prediction._model_name
+ assert VERSION_NAME == evaluate_prediction._version_name
+
+ assert TASK_PREFIX_SUMMARY == evaluate_summary.task_id
+ assert DATAFLOW_OPTIONS == evaluate_summary.dataflow_default_options
+ assert PREDICTION_PATH == evaluate_summary.options["prediction_path"]
+ assert METRIC_FN_ENCODED == evaluate_summary.options["metric_fn_encoded"]
+ assert METRIC_KEYS_EXPECTED == evaluate_summary.options["metric_keys"]
+
+ assert TASK_PREFIX_VALIDATION == evaluate_validation.task_id
+ assert PREDICTION_PATH == evaluate_validation.templates_dict["prediction_path"]
@mock.patch.object(PythonOperator, "set_upstream")
@mock.patch.object(DataflowCreatePythonJobOperator, "set_upstream")
@@ -192,24 +193,24 @@ def test_create_evaluate_ops_dag(self, mock_dataflow, mock_python):
mock_dataflow.assert_called_once_with(evaluate_prediction)
mock_python.assert_called_once_with(evaluate_summary)
- self.assertEqual(TASK_PREFIX_PREDICTION, evaluate_prediction.task_id)
- self.assertEqual(PROJECT_ID, evaluate_prediction._project_id)
- self.assertEqual(BATCH_PREDICTION_JOB_ID, evaluate_prediction._job_id)
- self.assertEqual(REGION, evaluate_prediction._region)
- self.assertEqual(DATA_FORMAT, evaluate_prediction._data_format)
- self.assertEqual(INPUT_PATHS, evaluate_prediction._input_paths)
- self.assertEqual(PREDICTION_PATH, evaluate_prediction._output_path)
- self.assertEqual(MODEL_NAME, evaluate_prediction._model_name)
- self.assertEqual(VERSION_NAME, evaluate_prediction._version_name)
-
- self.assertEqual(TASK_PREFIX_SUMMARY, evaluate_summary.task_id)
- self.assertEqual(DATAFLOW_OPTIONS, evaluate_summary.dataflow_default_options)
- self.assertEqual(PREDICTION_PATH, evaluate_summary.options["prediction_path"])
- self.assertEqual(METRIC_FN_ENCODED, evaluate_summary.options["metric_fn_encoded"])
- self.assertEqual(METRIC_KEYS_EXPECTED, evaluate_summary.options["metric_keys"])
-
- self.assertEqual(TASK_PREFIX_VALIDATION, evaluate_validation.task_id)
- self.assertEqual(PREDICTION_PATH, evaluate_validation.templates_dict["prediction_path"])
+ assert TASK_PREFIX_PREDICTION == evaluate_prediction.task_id
+ assert PROJECT_ID == evaluate_prediction._project_id
+ assert BATCH_PREDICTION_JOB_ID == evaluate_prediction._job_id
+ assert REGION == evaluate_prediction._region
+ assert DATA_FORMAT == evaluate_prediction._data_format
+ assert INPUT_PATHS == evaluate_prediction._input_paths
+ assert PREDICTION_PATH == evaluate_prediction._output_path
+ assert MODEL_NAME == evaluate_prediction._model_name
+ assert VERSION_NAME == evaluate_prediction._version_name
+
+ assert TASK_PREFIX_SUMMARY == evaluate_summary.task_id
+ assert DATAFLOW_OPTIONS == evaluate_summary.dataflow_default_options
+ assert PREDICTION_PATH == evaluate_summary.options["prediction_path"]
+ assert METRIC_FN_ENCODED == evaluate_summary.options["metric_fn_encoded"]
+ assert METRIC_KEYS_EXPECTED == evaluate_summary.options["metric_keys"]
+
+ assert TASK_PREFIX_VALIDATION == evaluate_validation.task_id
+ assert PREDICTION_PATH == evaluate_validation.templates_dict["prediction_path"]
@mock.patch.object(GCSHook, "download")
@mock.patch.object(PythonOperator, "set_upstream")
@@ -233,27 +234,25 @@ def test_apply_validate_fn(self, mock_dataflow, mock_python, mock_download):
mock_download.return_value = json.dumps({"err": 0.3, "mse": 0.04, "count": 1100})
templates_dict = {"prediction_path": PREDICTION_PATH}
- with self.assertRaises(ValueError) as context:
+ with pytest.raises(ValueError) as ctx:
evaluate_validation.python_callable(templates_dict=templates_dict)
- self.assertEqual(
- "Too high err>0.2; summary={'err': 0.3, 'mse': 0.04, 'count': 1100}", str(context.exception)
- )
+ assert "Too high err>0.2; summary={'err': 0.3, 'mse': 0.04, 'count': 1100}" == str(ctx.value)
mock_download.assert_called_once_with("path", "to/output/predictions.json/prediction.summary.json")
invalid_prediction_paths = ["://path/to/output/predictions.json", "gs://", ""]
for path in invalid_prediction_paths:
templates_dict = {"prediction_path": path}
- with self.assertRaises(ValueError) as context:
+ with pytest.raises(ValueError) as ctx:
evaluate_validation.python_callable(templates_dict=templates_dict)
- self.assertEqual("Wrong format prediction_path:", str(context.exception)[:29])
+ assert "Wrong format prediction_path:" == str(ctx.value)[:29]
def test_invalid_task_prefix(self):
invalid_task_prefix_values = ["test-task-prefix&", "~test-task-prefix", "test-task(-prefix"]
for invalid_task_prefix_value in invalid_task_prefix_values:
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
create_evaluate_ops(
task_prefix=invalid_task_prefix_value,
data_format=DATA_FORMAT,
@@ -264,7 +263,7 @@ def test_invalid_task_prefix(self):
)
def test_non_callable_metric_fn(self):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
create_evaluate_ops(
task_prefix=TASK_PREFIX,
data_format=DATA_FORMAT,
@@ -275,7 +274,7 @@ def test_non_callable_metric_fn(self):
)
def test_non_callable_validate_fn(self):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
create_evaluate_ops(
task_prefix=TASK_PREFIX,
data_format=DATA_FORMAT,
diff --git a/tests/providers/google/cloud/utils/test_mlengine_prediction_summary.py b/tests/providers/google/cloud/utils/test_mlengine_prediction_summary.py
index 1fb33c031a321..13356f55e045b 100644
--- a/tests/providers/google/cloud/utils/test_mlengine_prediction_summary.py
+++ b/tests/providers/google/cloud/utils/test_mlengine_prediction_summary.py
@@ -21,6 +21,7 @@
from unittest import mock
import dill
+import pytest
try:
from airflow.providers.google.cloud.utils import mlengine_prediction_summary
@@ -31,10 +32,10 @@
class TestJsonCode(unittest.TestCase):
def test_encode(self):
- self.assertEqual(b'{"a": 1}', mlengine_prediction_summary.JsonCoder.encode({'a': 1}))
+ assert b'{"a": 1}' == mlengine_prediction_summary.JsonCoder.encode({'a': 1})
def test_decode(self):
- self.assertEqual({'a': 1}, mlengine_prediction_summary.JsonCoder.decode('{"a": 1}'))
+ assert {'a': 1} == mlengine_prediction_summary.JsonCoder.decode('{"a": 1}')
class TestMakeSummary(unittest.TestCase):
@@ -42,17 +43,17 @@ def test_make_summary(self):
print(mlengine_prediction_summary.MakeSummary(1, lambda x: x, []))
def test_run_without_all_arguments_should_raise_exception(self):
- with self.assertRaises(SystemExit):
+ with pytest.raises(SystemExit):
mlengine_prediction_summary.run()
- with self.assertRaises(SystemExit):
+ with pytest.raises(SystemExit):
mlengine_prediction_summary.run(
[
"--prediction_path=some/path",
]
)
- with self.assertRaises(SystemExit):
+ with pytest.raises(SystemExit):
mlengine_prediction_summary.run(
[
"--prediction_path=some/path",
@@ -61,7 +62,7 @@ def test_run_without_all_arguments_should_raise_exception(self):
)
def test_run_should_fail_for_invalid_encoded_fn(self):
- with self.assertRaises(binascii.Error):
+ with pytest.raises(binascii.Error):
mlengine_prediction_summary.run(
[
"--prediction_path=some/path",
@@ -74,7 +75,7 @@ def test_run_should_fail_if_enc_fn_is_not_callable(self):
non_callable_value = 1
fn_enc = base64.b64encode(dill.dumps(non_callable_value)).decode('utf-8')
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
mlengine_prediction_summary.run(
[
"--prediction_path=some/path",
diff --git a/tests/providers/google/common/auth_backend/test_google_openid.py b/tests/providers/google/common/auth_backend/test_google_openid.py
index 32f5f8f8e4780..edaa76b4a52af 100644
--- a/tests/providers/google/common/auth_backend/test_google_openid.py
+++ b/tests/providers/google/common/auth_backend/test_google_openid.py
@@ -63,10 +63,10 @@ def test_success(self, mock_verify_token):
response = test_client.get(
"/api/experimental/pools", headers={"Authorization": "bearer JWT_TOKEN"}
)
- self.assertEqual("test@fab.org", current_user.email)
+ assert "test@fab.org" == current_user.email
- self.assertEqual(200, response.status_code)
- self.assertIn("Default pool", str(response.json))
+ assert 200 == response.status_code
+ assert "Default pool" in str(response.json)
@parameterized.expand([("bearer",), ("JWT_TOKEN",), ("bearer ",)])
@mock.patch("google.oauth2.id_token.verify_token")
@@ -80,8 +80,8 @@ def test_malformed_headers(self, auth_header, mock_verify_token):
with self.app.test_client() as test_client:
response = test_client.get("/api/experimental/pools", headers={"Authorization": auth_header})
- self.assertEqual(403, response.status_code)
- self.assertEqual("Forbidden", response.data.decode())
+ assert 403 == response.status_code
+ assert "Forbidden" == response.data.decode()
@mock.patch("google.oauth2.id_token.verify_token")
def test_invalid_iss_in_jwt_token(self, mock_verify_token):
@@ -96,8 +96,8 @@ def test_invalid_iss_in_jwt_token(self, mock_verify_token):
"/api/experimental/pools", headers={"Authorization": "bearer JWT_TOKEN"}
)
- self.assertEqual(403, response.status_code)
- self.assertEqual("Forbidden", response.data.decode())
+ assert 403 == response.status_code
+ assert "Forbidden" == response.data.decode()
@mock.patch("google.oauth2.id_token.verify_token")
def test_user_not_exists(self, mock_verify_token):
@@ -112,16 +112,16 @@ def test_user_not_exists(self, mock_verify_token):
"/api/experimental/pools", headers={"Authorization": "bearer JWT_TOKEN"}
)
- self.assertEqual(403, response.status_code)
- self.assertEqual("Forbidden", response.data.decode())
+ assert 403 == response.status_code
+ assert "Forbidden" == response.data.decode()
@conf_vars({("api", "auth_backend"): "airflow.providers.google.common.auth_backend.google_openid"})
def test_missing_id_token(self):
with self.app.test_client() as test_client:
response = test_client.get("/api/experimental/pools")
- self.assertEqual(403, response.status_code)
- self.assertEqual("Forbidden", response.data.decode())
+ assert 403 == response.status_code
+ assert "Forbidden" == response.data.decode()
@conf_vars({("api", "auth_backend"): "airflow.providers.google.common.auth_backend.google_openid"})
@mock.patch("google.oauth2.id_token.verify_token")
@@ -133,5 +133,5 @@ def test_invalid_id_token(self, mock_verify_token):
"/api/experimental/pools", headers={"Authorization": "bearer JWT_TOKEN"}
)
- self.assertEqual(403, response.status_code)
- self.assertEqual("Forbidden", response.data.decode())
+ assert 403 == response.status_code
+ assert "Forbidden" == response.data.decode()
diff --git a/tests/providers/google/common/hooks/test_base_google.py b/tests/providers/google/common/hooks/test_base_google.py
index 70c5d6edd0bf8..d59bd4d645640 100644
--- a/tests/providers/google/common/hooks/test_base_google.py
+++ b/tests/providers/google/common/hooks/test_base_google.py
@@ -24,6 +24,7 @@
from unittest import mock
import google.auth
+import pytest
import tenacity
from google.auth.environment_vars import CREDENTIALS
from google.auth.exceptions import GoogleAuthError
@@ -74,17 +75,17 @@ def _retryable_test_with_temporary_quota_retry(thing):
class QuotaRetryTestCase(unittest.TestCase): # ptlint: disable=invalid-name
def test_do_nothing_on_non_error(self):
result = _retryable_test_with_temporary_quota_retry(lambda: 42)
- self.assertTrue(result, 42)
+ assert result, 42
def test_retry_on_exception(self):
message = "POST https://translation.googleapis.com/language/translate/v2: User Rate Limit Exceeded"
errors = [mock.MagicMock(details=mock.PropertyMock(return_value='userRateLimitExceeded'))]
custom_fn = NoForbiddenAfterCount(count=5, message=message, errors=errors)
_retryable_test_with_temporary_quota_retry(custom_fn)
- self.assertEqual(5, custom_fn.counter)
+ assert 5 == custom_fn.counter
def test_raise_exception_on_non_quota_exception(self):
- with self.assertRaisesRegex(Forbidden, "Daily Limit Exceeded"):
+ with pytest.raises(Forbidden, match="Daily Limit Exceeded"):
message = "POST https://translation.googleapis.com/language/translate/v2: Daily Limit Exceeded"
errors = [mock.MagicMock(details=mock.PropertyMock(return_value='dailyLimitExceeded'))]
@@ -132,13 +133,11 @@ def test_provided_project_id(self):
def test_restrict_positional_arguments(self):
gcp_hook = FallbackToDefaultProjectIdFixtureClass(321)
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
gcp_hook.method(123)
- self.assertEqual(
- str(cm.exception), "You must use keyword arguments in this methods rather than positional"
- )
- self.assertEqual(gcp_hook.mock.call_count, 0)
+ assert str(ctx.value) == "You must use keyword arguments in this methods rather than positional"
+ assert gcp_hook.mock.call_count == 0
ENV_VALUE = "/tmp/a"
@@ -161,11 +160,11 @@ def test_provide_gcp_credential_file_decorator_key_path_and_keyfile_dict(self):
@hook.GoogleBaseHook.provide_gcp_credential_file
def assert_gcp_credential_file_in_env(_):
- self.assertEqual(os.environ[CREDENTIALS], key_path)
+ assert os.environ[CREDENTIALS] == key_path
- with self.assertRaisesRegex(
+ with pytest.raises(
AirflowException,
- 'The `keyfile_dict` and `key_path` fields are mutually exclusive. '
+ match='The `keyfile_dict` and `key_path` fields are mutually exclusive. '
'Please provide only one value.',
):
assert_gcp_credential_file_in_env(self.instance)
@@ -176,7 +175,7 @@ def test_provide_gcp_credential_file_decorator_key_path(self):
@hook.GoogleBaseHook.provide_gcp_credential_file
def assert_gcp_credential_file_in_env(_):
- self.assertEqual(os.environ[CREDENTIALS], key_path)
+ assert os.environ[CREDENTIALS] == key_path
assert_gcp_credential_file_in_env(self.instance)
@@ -192,8 +191,8 @@ def test_provide_gcp_credential_file_decorator_key_content(self, mock_file):
@hook.GoogleBaseHook.provide_gcp_credential_file
def assert_gcp_credential_file_in_env(_):
- self.assertEqual(os.environ[CREDENTIALS], file_name)
- self.assertEqual(file_content, string_file.getvalue())
+ assert os.environ[CREDENTIALS] == file_name
+ assert file_content == string_file.getvalue()
assert_gcp_credential_file_in_env(self.instance)
@@ -204,10 +203,10 @@ def test_provide_gcp_credential_keep_environment(self):
@hook.GoogleBaseHook.provide_gcp_credential_file
def assert_gcp_credential_file_in_env(_):
- self.assertEqual(os.environ[CREDENTIALS], key_path)
+ assert os.environ[CREDENTIALS] == key_path
assert_gcp_credential_file_in_env(self.instance)
- self.assertEqual(os.environ[CREDENTIALS], ENV_VALUE)
+ assert os.environ[CREDENTIALS] == ENV_VALUE
@mock.patch.dict(os.environ, {CREDENTIALS: ENV_VALUE})
def test_provide_gcp_credential_keep_environment_when_exception(self):
@@ -218,10 +217,10 @@ def test_provide_gcp_credential_keep_environment_when_exception(self):
def assert_gcp_credential_file_in_env(_):
raise Exception()
- with self.assertRaises(Exception):
+ with pytest.raises(Exception):
assert_gcp_credential_file_in_env(self.instance)
- self.assertEqual(os.environ[CREDENTIALS], ENV_VALUE)
+ assert os.environ[CREDENTIALS] == ENV_VALUE
@mock.patch.dict(os.environ, clear=True)
def test_provide_gcp_credential_keep_clear_environment(self):
@@ -230,10 +229,10 @@ def test_provide_gcp_credential_keep_clear_environment(self):
@hook.GoogleBaseHook.provide_gcp_credential_file
def assert_gcp_credential_file_in_env(_):
- self.assertEqual(os.environ[CREDENTIALS], key_path)
+ assert os.environ[CREDENTIALS] == key_path
assert_gcp_credential_file_in_env(self.instance)
- self.assertNotIn(CREDENTIALS, os.environ)
+ assert CREDENTIALS not in os.environ
@mock.patch.dict(os.environ, clear=True)
def test_provide_gcp_credential_keep_clear_environment_when_exception(self):
@@ -244,10 +243,10 @@ def test_provide_gcp_credential_keep_clear_environment_when_exception(self):
def assert_gcp_credential_file_in_env(_):
raise Exception()
- with self.assertRaises(Exception):
+ with pytest.raises(Exception):
assert_gcp_credential_file_in_env(self.instance)
- self.assertNotIn(CREDENTIALS, os.environ)
+ assert CREDENTIALS not in os.environ
class TestProvideGcpCredentialFileAsContext(unittest.TestCase):
@@ -263,7 +262,7 @@ def test_provide_gcp_credential_file_decorator_key_path(self):
self.instance.extras = {'extra__google_cloud_platform__key_path': key_path}
with self.instance.provide_gcp_credential_file_as_context():
- self.assertEqual(os.environ[CREDENTIALS], key_path)
+ assert os.environ[CREDENTIALS] == key_path
@mock.patch('tempfile.NamedTemporaryFile')
def test_provide_gcp_credential_file_decorator_key_content(self, mock_file):
@@ -276,8 +275,8 @@ def test_provide_gcp_credential_file_decorator_key_content(self, mock_file):
mock_file_handler.write = string_file.write
with self.instance.provide_gcp_credential_file_as_context():
- self.assertEqual(os.environ[CREDENTIALS], file_name)
- self.assertEqual(file_content, string_file.getvalue())
+ assert os.environ[CREDENTIALS] == file_name
+ assert file_content == string_file.getvalue()
@mock.patch.dict(os.environ, {CREDENTIALS: ENV_VALUE})
def test_provide_gcp_credential_keep_environment(self):
@@ -285,20 +284,20 @@ def test_provide_gcp_credential_keep_environment(self):
self.instance.extras = {'extra__google_cloud_platform__key_path': key_path}
with self.instance.provide_gcp_credential_file_as_context():
- self.assertEqual(os.environ[CREDENTIALS], key_path)
+ assert os.environ[CREDENTIALS] == key_path
- self.assertEqual(os.environ[CREDENTIALS], ENV_VALUE)
+ assert os.environ[CREDENTIALS] == ENV_VALUE
@mock.patch.dict(os.environ, {CREDENTIALS: ENV_VALUE})
def test_provide_gcp_credential_keep_environment_when_exception(self):
key_path = '/test/key-path'
self.instance.extras = {'extra__google_cloud_platform__key_path': key_path}
- with self.assertRaises(Exception):
+ with pytest.raises(Exception):
with self.instance.provide_gcp_credential_file_as_context():
raise Exception()
- self.assertEqual(os.environ[CREDENTIALS], ENV_VALUE)
+ assert os.environ[CREDENTIALS] == ENV_VALUE
@mock.patch.dict(os.environ, clear=True)
def test_provide_gcp_credential_keep_clear_environment(self):
@@ -306,20 +305,20 @@ def test_provide_gcp_credential_keep_clear_environment(self):
self.instance.extras = {'extra__google_cloud_platform__key_path': key_path}
with self.instance.provide_gcp_credential_file_as_context():
- self.assertEqual(os.environ[CREDENTIALS], key_path)
+ assert os.environ[CREDENTIALS] == key_path
- self.assertNotIn(CREDENTIALS, os.environ)
+ assert CREDENTIALS not in os.environ
@mock.patch.dict(os.environ, clear=True)
def test_provide_gcp_credential_keep_clear_environment_when_exception(self):
key_path = '/test/key-path'
self.instance.extras = {'extra__google_cloud_platform__key_path': key_path}
- with self.assertRaises(Exception):
+ with pytest.raises(Exception):
with self.instance.provide_gcp_credential_file_as_context():
raise Exception()
- self.assertNotIn(CREDENTIALS, os.environ)
+ assert CREDENTIALS not in os.environ
class TestGoogleBaseHook(unittest.TestCase):
@@ -338,7 +337,7 @@ def test_get_credentials_and_project_id_with_default_auth(self, mock_get_creds_a
target_principal=None,
delegates=None,
)
- self.assertEqual(('CREDENTIALS', 'PROJECT_ID'), result)
+ assert ('CREDENTIALS', 'PROJECT_ID') == result
@mock.patch(MODULE_NAME + '.get_credentials_and_project_id')
def test_get_credentials_and_project_id_with_service_account_file(self, mock_get_creds_and_proj_id):
@@ -354,16 +353,16 @@ def test_get_credentials_and_project_id_with_service_account_file(self, mock_get
target_principal=None,
delegates=None,
)
- self.assertEqual((mock_credentials, 'PROJECT_ID'), result)
+ assert (mock_credentials, 'PROJECT_ID') == result
def test_get_credentials_and_project_id_with_service_account_file_and_p12_key(self):
self.instance.extras = {'extra__google_cloud_platform__key_path': "KEY_PATH.p12"}
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.instance._get_credentials_and_project_id()
def test_get_credentials_and_project_id_with_service_account_file_and_unknown_key(self):
self.instance.extras = {'extra__google_cloud_platform__key_path': "KEY_PATH.unknown"}
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.instance._get_credentials_and_project_id()
@mock.patch(MODULE_NAME + '.get_credentials_and_project_id')
@@ -381,7 +380,7 @@ def test_get_credentials_and_project_id_with_service_account_info(self, mock_get
target_principal=None,
delegates=None,
)
- self.assertEqual((mock_credentials, 'PROJECT_ID'), result)
+ assert (mock_credentials, 'PROJECT_ID') == result
@mock.patch(MODULE_NAME + '.get_credentials_and_project_id')
def test_get_credentials_and_project_id_with_default_auth_and_delegate(self, mock_get_creds_and_proj_id):
@@ -398,7 +397,7 @@ def test_get_credentials_and_project_id_with_default_auth_and_delegate(self, moc
target_principal=None,
delegates=None,
)
- self.assertEqual((mock_credentials, "PROJECT_ID"), result)
+ assert (mock_credentials, "PROJECT_ID") == result
@mock.patch('google.auth.default')
def test_get_credentials_and_project_id_with_default_auth_and_unsupported_delegate(
@@ -408,9 +407,9 @@ def test_get_credentials_and_project_id_with_default_auth_and_unsupported_delega
mock_credentials = mock.MagicMock(spec=google.auth.compute_engine.Credentials)
mock_auth_default.return_value = (mock_credentials, "PROJECT_ID")
- with self.assertRaisesRegex(
+ with pytest.raises(
AirflowException,
- re.escape(
+ match=re.escape(
"The `delegate_to` parameter cannot be used here as the current authentication method "
"does not support account impersonate. Please use service-account for authorization."
),
@@ -431,7 +430,7 @@ def test_get_credentials_and_project_id_with_default_auth_and_overridden_project
target_principal=None,
delegates=None,
)
- self.assertEqual(("CREDENTIALS", 'SECOND_PROJECT_ID'), result)
+ assert ("CREDENTIALS", 'SECOND_PROJECT_ID') == result
def test_get_credentials_and_project_id_with_mutually_exclusive_configuration(
self,
@@ -441,8 +440,9 @@ def test_get_credentials_and_project_id_with_mutually_exclusive_configuration(
'extra__google_cloud_platform__key_path': "KEY_PATH",
'extra__google_cloud_platform__keyfile_dict': '{"KEY": "VALUE"}',
}
- with self.assertRaisesRegex(
- AirflowException, re.escape('The `keyfile_dict` and `key_path` fields are mutually exclusive.')
+ with pytest.raises(
+ AirflowException,
+ match=re.escape('The `keyfile_dict` and `key_path` fields are mutually exclusive.'),
):
self.instance._get_credentials_and_project_id()
@@ -452,7 +452,7 @@ def test_get_credentials_and_project_id_with_invalid_keyfile_dict(
self.instance.extras = {
'extra__google_cloud_platform__keyfile_dict': 'INVALID_DICT',
}
- with self.assertRaisesRegex(AirflowException, re.escape('Invalid key JSON.')):
+ with pytest.raises(AirflowException, match=re.escape('Invalid key JSON.')):
self.instance._get_credentials_and_project_id()
@unittest.skipIf(
@@ -479,8 +479,8 @@ def test_default_creds_with_scopes(self):
return
scopes = credentials.scopes
- self.assertIn('https://www.googleapis.com/auth/bigquery', scopes)
- self.assertIn('https://www.googleapis.com/auth/devstorage.read_only', scopes)
+ assert 'https://www.googleapis.com/auth/bigquery' in scopes
+ assert 'https://www.googleapis.com/auth/devstorage.read_only' in scopes
@unittest.skipIf(
not default_creds_available, 'Default Google Cloud credentials not available to run tests'
@@ -496,7 +496,7 @@ def test_default_creds_no_scopes(self):
return
scopes = credentials.scopes
- self.assertEqual(tuple(_DEFAULT_SCOPES), tuple(scopes))
+ assert tuple(_DEFAULT_SCOPES) == tuple(scopes)
def test_provide_gcp_credential_file_decorator_key_path(self):
key_path = '/test/key-path'
@@ -504,7 +504,7 @@ def test_provide_gcp_credential_file_decorator_key_path(self):
@hook.GoogleBaseHook.provide_gcp_credential_file
def assert_gcp_credential_file_in_env(hook_instance): # pylint: disable=unused-argument
- self.assertEqual(os.environ[CREDENTIALS], key_path)
+ assert os.environ[CREDENTIALS] == key_path
assert_gcp_credential_file_in_env(self.instance)
@@ -520,8 +520,8 @@ def test_provide_gcp_credential_file_decorator_key_content(self, mock_file):
@hook.GoogleBaseHook.provide_gcp_credential_file
def assert_gcp_credential_file_in_env(hook_instance): # pylint: disable=unused-argument
- self.assertEqual(os.environ[CREDENTIALS], file_name)
- self.assertEqual(file_content, string_file.getvalue())
+ assert os.environ[CREDENTIALS] == file_name
+ assert file_content == string_file.getvalue()
assert_gcp_credential_file_in_env(self.instance)
@@ -538,18 +538,15 @@ def test_provided_scopes(self):
),
}
- self.assertEqual(
- self.instance.scopes,
- [
- 'https://www.googleapis.com/auth/bigquery',
- 'https://www.googleapis.com/auth/devstorage.read_only',
- ],
- )
+ assert self.instance.scopes == [
+ 'https://www.googleapis.com/auth/bigquery',
+ 'https://www.googleapis.com/auth/devstorage.read_only',
+ ]
def test_default_scopes(self):
self.instance.extras = {'extra__google_cloud_platform__project': default_project}
- self.assertEqual(self.instance.scopes, ('https://www.googleapis.com/auth/cloud-platform',))
+ assert self.instance.scopes == ('https://www.googleapis.com/auth/cloud-platform',)
@mock.patch("airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_connection")
def test_num_retries_is_not_none_by_default(self, get_con_mock):
@@ -558,7 +555,7 @@ def test_num_retries_is_not_none_by_default(self, get_con_mock):
should not be None
"""
get_con_mock.return_value.extra_dejson = {"extra__google_cloud_platform__num_retries": None}
- self.assertEqual(self.instance.num_retries, 5)
+ assert self.instance.num_retries == 5
@mock.patch("airflow.providers.google.common.hooks.base_google.build_http")
@mock.patch("airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials")
@@ -582,8 +579,8 @@ def test_authorize_assert_user_agent_is_sent(self, mock_get_credentials, mock_ht
method='GET',
redirections=5,
)
- self.assertEqual(response, new_response)
- self.assertEqual(content, new_content)
+ assert response == new_response
+ assert content == new_content
@mock.patch("airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials")
def test_authorize_assert_http_308_is_excluded(self, mock_get_credentials):
@@ -591,7 +588,7 @@ def test_authorize_assert_http_308_is_excluded(self, mock_get_credentials):
Verify that 308 status code is excluded from httplib2's redirect codes
"""
http_authorized = self.instance._authorize().http
- self.assertTrue(308 not in http_authorized.redirect_codes)
+ assert 308 not in http_authorized.redirect_codes
@mock.patch("airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials")
def test_authorize_assert_http_timeout_is_present(self, mock_get_credentials):
@@ -599,7 +596,7 @@ def test_authorize_assert_http_timeout_is_present(self, mock_get_credentials):
Verify that http client has a timeout set
"""
http_authorized = self.instance._authorize().http
- self.assertNotEqual(http_authorized.timeout, None)
+ assert http_authorized.timeout is not None
@parameterized.expand(
[
@@ -634,7 +631,7 @@ def test_get_credentials_and_project_id_with_impersonation_chain(
target_principal=target_principal,
delegates=delegates,
)
- self.assertEqual((mock_credentials, PROJECT_ID), result)
+ assert (mock_credentials, PROJECT_ID) == result
class TestProvideAuthorizedGcloud(unittest.TestCase):
@@ -658,13 +655,13 @@ def test_provide_authorized_gcloud_key_path_and_keyfile_dict(self, mock_check_ou
'extra__google_cloud_platform__keyfile_dict': '{"foo": "bar"}',
}
- with self.assertRaisesRegex(
+ with pytest.raises(
AirflowException,
- 'The `keyfile_dict` and `key_path` fields are mutually exclusive. '
+ match='The `keyfile_dict` and `key_path` fields are mutually exclusive. '
'Please provide only one value.',
):
with self.instance.provide_authorized_gcloud():
- self.assertEqual(os.environ[CREDENTIALS], key_path)
+ assert os.environ[CREDENTIALS] == key_path
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -677,7 +674,7 @@ def test_provide_authorized_gcloud_key_path(self, mock_check_output, mock_projec
self.instance.extras = {'extra__google_cloud_platform__key_path': key_path}
with self.instance.provide_authorized_gcloud():
- self.assertEqual(os.environ[CREDENTIALS], key_path)
+ assert os.environ[CREDENTIALS] == key_path
mock_check_output.has_calls(
mock.call(['gcloud', 'config', 'set', 'core/project', 'PROJECT_ID']),
@@ -701,7 +698,7 @@ def test_provide_authorized_gcloud_keyfile_dict(self, mock_file, mock_check_outp
mock_file_handler.write = string_file.write
with self.instance.provide_authorized_gcloud():
- self.assertEqual(os.environ[CREDENTIALS], file_name)
+ assert os.environ[CREDENTIALS] == file_name
mock_check_output.has_calls(
[
@@ -755,8 +752,8 @@ def test_should_return_int_when_set_int_via_connection(self):
'extra__google_cloud_platform__num_retries': 10,
}
- self.assertIsInstance(instance.num_retries, int)
- self.assertEqual(10, instance.num_retries)
+ assert isinstance(instance.num_retries, int)
+ assert 10 == instance.num_retries
@mock.patch.dict(
'os.environ',
@@ -766,7 +763,7 @@ def test_should_return_int_when_set_int_via_connection(self):
)
def test_should_return_int_when_set_via_env_var(self):
instance = hook.GoogleBaseHook(gcp_conn_id="google_cloud_default")
- self.assertIsInstance(instance.num_retries, int)
+ assert isinstance(instance.num_retries, int)
@mock.patch.dict(
'os.environ',
@@ -776,10 +773,8 @@ def test_should_return_int_when_set_via_env_var(self):
)
def test_should_raise_when_invalid_value_via_env_var(self):
instance = hook.GoogleBaseHook(gcp_conn_id="google_cloud_default")
- with self.assertRaisesRegex(
- AirflowException, re.escape("The num_retries field should be a integer.")
- ):
- self.assertIsInstance(instance.num_retries, int)
+ with pytest.raises(AirflowException, match=re.escape("The num_retries field should be a integer.")):
+ assert isinstance(instance.num_retries, int)
@mock.patch.dict(
'os.environ',
@@ -789,5 +784,5 @@ def test_should_raise_when_invalid_value_via_env_var(self):
)
def test_should_fallback_when_empty_string_in_env_var(self):
instance = hook.GoogleBaseHook(gcp_conn_id="google_cloud_default")
- self.assertIsInstance(instance.num_retries, int)
- self.assertEqual(5, instance.num_retries)
+ assert isinstance(instance.num_retries, int)
+ assert 5 == instance.num_retries
diff --git a/tests/providers/google/common/utils/test_id_token_credentials.py b/tests/providers/google/common/utils/test_id_token_credentials.py
index debfc5a9beeda..2719fbb55042f 100644
--- a/tests/providers/google/common/utils/test_id_token_credentials.py
+++ b/tests/providers/google/common/utils/test_id_token_credentials.py
@@ -21,6 +21,7 @@
import unittest
from unittest import mock
+import pytest
from google.auth import exceptions
from google.auth.environment_vars import CREDENTIALS
@@ -36,12 +37,12 @@ def test_should_use_id_token_from_parent_credentials(self):
type(parent_credentials).id_token = mock.PropertyMock(side_effect=["ID_TOKEN1", "ID_TOKEN2"])
creds = IDTokenCredentialsAdapter(credentials=parent_credentials)
- self.assertEqual(creds.token, "ID_TOKEN1")
+ assert creds.token == "ID_TOKEN1"
request_adapter = mock.MagicMock()
creds.refresh(request_adapter)
- self.assertEqual(creds.token, "ID_TOKEN2")
+ assert creds.token == "ID_TOKEN2"
class TestGetDefaultIdTokenCredentials(unittest.TestCase):
@@ -57,9 +58,9 @@ class TestGetDefaultIdTokenCredentials(unittest.TestCase):
def test_should_raise_exception(self, mock_metadata_ping, mock_gcloud_sdk_path):
if CREDENTIALS in os.environ:
del os.environ[CREDENTIALS]
- with self.assertRaisesRegex(
+ with pytest.raises(
exceptions.DefaultCredentialsError,
- re.escape(
+ match=re.escape(
"Could not automatically determine credentials. Please set GOOGLE_APPLICATION_CREDENTIALS "
"or explicitly create credentials and re-run the application. For more information, please "
"see https://cloud.google.com/docs/authentication/getting-started"
@@ -83,9 +84,7 @@ def test_should_support_metadata_credentials(self, credentials, mock_metadata_pi
if CREDENTIALS in os.environ:
del os.environ[CREDENTIALS]
- self.assertEqual(
- credentials.return_value, get_default_id_token_credentials(target_audience="example.org")
- )
+ assert credentials.return_value == get_default_id_token_credentials(target_audience="example.org")
@mock.patch.dict("os.environ")
@mock.patch(
@@ -107,8 +106,8 @@ def test_should_support_user_credentials_from_gcloud(self, mock_gcloud_sdk_path)
del os.environ[CREDENTIALS]
credentials = get_default_id_token_credentials(target_audience="example.org")
- self.assertIsInstance(credentials, IDTokenCredentialsAdapter)
- self.assertEqual(credentials.credentials.client_secret, "CLIENT_SECRET")
+ assert isinstance(credentials, IDTokenCredentialsAdapter)
+ assert credentials.credentials.client_secret == "CLIENT_SECRET"
@mock.patch.dict("os.environ")
@mock.patch(
@@ -137,7 +136,7 @@ def test_should_support_service_account_from_gcloud(self, mock_gcloud_sdk_path,
del os.environ[CREDENTIALS]
credentials = get_default_id_token_credentials(target_audience="example.org")
- self.assertEqual(credentials.service_account_email, "CLIENT_EMAIL")
+ assert credentials.service_account_email == "CLIENT_EMAIL"
@mock.patch.dict("os.environ")
@mock.patch(
@@ -164,4 +163,4 @@ def test_should_support_service_account_from_env(self, mock_gcloud_sdk_path):
os.environ[CREDENTIALS] = __file__
credentials = get_default_id_token_credentials(target_audience="example.org")
- self.assertEqual(credentials.service_account_email, "CLIENT_EMAIL")
+ assert credentials.service_account_email == "CLIENT_EMAIL"
diff --git a/tests/providers/google/firebase/hooks/test_firestore.py b/tests/providers/google/firebase/hooks/test_firestore.py
index 9703c4fccba95..507cddc1a8098 100644
--- a/tests/providers/google/firebase/hooks/test_firestore.py
+++ b/tests/providers/google/firebase/hooks/test_firestore.py
@@ -23,6 +23,8 @@
from unittest import mock
from unittest.mock import PropertyMock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.google.firebase.hooks.firestore import CloudFirestoreHook
from tests.providers.google.cloud.utils.base_gcp_mock import (
@@ -64,8 +66,8 @@ def test_client_creation(self, mock_build_from_document, mock_build, mock_author
mock_build_from_document.assert_called_once_with(
mock_build.return_value._rootDesc, http=mock_authorize.return_value
)
- self.assertEqual(mock_build_from_document.return_value, result)
- self.assertEqual(self.hook._conn, result)
+ assert mock_build_from_document.return_value == result
+ assert self.hook._conn == result
@mock.patch("airflow.providers.google.firebase.hooks.firestore.CloudFirestoreHook.get_conn")
def test_immediately_complete(self, get_conn_mock):
@@ -120,7 +122,7 @@ def test_error_operation(self, _, get_conn_mock):
execute_mock = mock.Mock(**{"side_effect": [TEST_WAITING_OPERATION, TEST_ERROR_OPERATION]})
mock_operation_get.return_value.execute = execute_mock
- with self.assertRaisesRegex(AirflowException, "error"):
+ with pytest.raises(AirflowException, match="error"):
self.hook.export_documents(body=EXPORT_DOCUMENT_BODY, project_id=TEST_PROJECT_ID)
@@ -143,8 +145,8 @@ def test_client_creation(self, mock_build_from_document, mock_build, mock_author
mock_build_from_document.assert_called_once_with(
mock_build.return_value._rootDesc, http=mock_authorize.return_value
)
- self.assertEqual(mock_build_from_document.return_value, result)
- self.assertEqual(self.hook._conn, result)
+ assert mock_build_from_document.return_value == result
+ assert self.hook._conn == result
@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
@@ -214,7 +216,7 @@ def test_error_operation(self, _, get_conn_mock, mock_project_id):
execute_mock = mock.Mock(**{"side_effect": [TEST_WAITING_OPERATION, TEST_ERROR_OPERATION]})
mock_operation_get.return_value.execute = execute_mock
- with self.assertRaisesRegex(AirflowException, "error"):
+ with pytest.raises(AirflowException, match="error"):
self.hook.export_documents(body=EXPORT_DOCUMENT_BODY)
@@ -235,11 +237,10 @@ def setUp(self):
)
@mock.patch("airflow.providers.google.firebase.hooks.firestore.CloudFirestoreHook.get_conn")
def test_create_build(self, mock_get_conn, mock_project_id):
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
self.hook.export_documents(body={})
- self.assertEqual(
+ assert (
"The project id must be passed either as keyword project_id parameter or as project_id extra in "
- "Google Cloud connection definition. Both are not set!",
- str(e.exception),
+ "Google Cloud connection definition. Both are not set!" == str(ctx.value)
)
diff --git a/tests/providers/google/marketing_platform/hooks/test_analytics.py b/tests/providers/google/marketing_platform/hooks/test_analytics.py
index abd20b0a01e97..1623d92b754d1 100644
--- a/tests/providers/google/marketing_platform/hooks/test_analytics.py
+++ b/tests/providers/google/marketing_platform/hooks/test_analytics.py
@@ -52,7 +52,7 @@ def test_init(self, mock_base_init):
delegate_to=DELEGATE_TO,
impersonation_chain=IMPERSONATION_CHAIN,
)
- self.assertEqual(hook.api_version, API_VERSION)
+ assert hook.api_version == API_VERSION
@mock.patch("airflow.providers.google.marketing_platform.hooks.analytics.GoogleAnalyticsHook._authorize")
@mock.patch("airflow.providers.google.marketing_platform.hooks.analytics.build")
@@ -64,7 +64,7 @@ def test_gen_conn(self, mock_build, mock_authorize):
http=mock_authorize.return_value,
cache_discovery=False,
)
- self.assertEqual(mock_build.return_value, result)
+ assert mock_build.return_value == result
@mock.patch("airflow.providers.google.marketing_platform.hooks.analytics.GoogleAnalyticsHook.get_conn")
def test_list_accounts(self, get_conn_mock):
@@ -73,7 +73,7 @@ def test_list_accounts(self, get_conn_mock):
mock_execute = mock_list.return_value.execute
mock_execute.return_value = {"items": ["a", "b"], "totalResults": 2}
list_accounts = self.hook.list_accounts()
- self.assertEqual(list_accounts, ["a", "b"])
+ assert list_accounts == ["a", "b"]
@mock.patch("airflow.providers.google.marketing_platform.hooks.analytics.GoogleAnalyticsHook.get_conn")
def test_list_accounts_for_multiple_pages(self, get_conn_mock):
@@ -85,7 +85,7 @@ def test_list_accounts_for_multiple_pages(self, get_conn_mock):
{"items": ["b"], "totalResults": 2},
]
list_accounts = self.hook.list_accounts()
- self.assertEqual(list_accounts, ["a", "b"])
+ assert list_accounts == ["a", "b"]
@mock.patch("airflow.providers.google.marketing_platform.hooks.analytics.GoogleAnalyticsHook.get_conn")
def test_get_ad_words_links_call(self, get_conn_mock):
@@ -116,7 +116,7 @@ def test_list_ad_words_links(self, get_conn_mock):
mock_execute = mock_list.return_value.execute
mock_execute.return_value = {"items": ["a", "b"], "totalResults": 2}
list_ads_links = self.hook.list_ad_words_links(account_id=ACCOUNT_ID, web_property_id=WEB_PROPERTY_ID)
- self.assertEqual(list_ads_links, ["a", "b"])
+ assert list_ads_links == ["a", "b"]
@mock.patch("airflow.providers.google.marketing_platform.hooks.analytics.GoogleAnalyticsHook.get_conn")
def test_list_ad_words_links_for_multiple_pages(self, get_conn_mock):
@@ -128,7 +128,7 @@ def test_list_ad_words_links_for_multiple_pages(self, get_conn_mock):
{"items": ["b"], "totalResults": 2},
]
list_ads_links = self.hook.list_ad_words_links(account_id=ACCOUNT_ID, web_property_id=WEB_PROPERTY_ID)
- self.assertEqual(list_ads_links, ["a", "b"])
+ assert list_ads_links == ["a", "b"]
@mock.patch("airflow.providers.google.marketing_platform.hooks.analytics.GoogleAnalyticsHook.get_conn")
@mock.patch("airflow.providers.google.marketing_platform.hooks.analytics.MediaFileUpload")
@@ -184,4 +184,4 @@ def test_list_upload(self, get_conn_mock):
web_property_id=WEB_PROPERTY_ID,
custom_data_source_id=DATA_SOURCE,
)
- self.assertEqual(result, ["a", "b"])
+ assert result == ["a", "b"]
diff --git a/tests/providers/google/marketing_platform/hooks/test_campaign_manager.py b/tests/providers/google/marketing_platform/hooks/test_campaign_manager.py
index 3bdf5da0274e6..9eae66a9cf21d 100644
--- a/tests/providers/google/marketing_platform/hooks/test_campaign_manager.py
+++ b/tests/providers/google/marketing_platform/hooks/test_campaign_manager.py
@@ -51,7 +51,7 @@ def test_gen_conn(self, mock_build, mock_authorize):
http=mock_authorize.return_value,
cache_discovery=False,
)
- self.assertEqual(mock_build.return_value, result)
+ assert mock_build.return_value == result
@mock.patch(
"airflow.providers.google.marketing_platform.hooks."
@@ -69,7 +69,7 @@ def test_delete_report(self, get_conn_mock):
profileId=PROFILE_ID, reportId=REPORT_ID
)
- self.assertEqual(return_value, result)
+ assert return_value == result
@mock.patch(
"airflow.providers.google.marketing_platform.hooks."
@@ -90,7 +90,7 @@ def test_get_report(self, get_conn_mock):
profileId=PROFILE_ID, reportId=REPORT_ID, fileId=file_id
)
- self.assertEqual(return_value, result)
+ assert return_value == result
@mock.patch(
"airflow.providers.google.marketing_platform.hooks."
@@ -110,7 +110,7 @@ def test_get_report_file(self, get_conn_mock):
profileId=PROFILE_ID, reportId=REPORT_ID, fileId=file_id
)
- self.assertEqual(return_value, result)
+ assert return_value == result
@mock.patch(
"airflow.providers.google.marketing_platform.hooks."
@@ -130,7 +130,7 @@ def test_insert_report(self, get_conn_mock):
profileId=PROFILE_ID, body=report
)
- self.assertEqual(return_value, result)
+ assert return_value == result
@mock.patch(
"airflow.providers.google.marketing_platform.hooks."
@@ -171,7 +171,7 @@ def test_list_reports(self, get_conn_mock):
sortOrder=sort_order,
)
- self.assertEqual(items * 4, result)
+ assert items * 4 == result
@mock.patch(
"airflow.providers.google.marketing_platform.hooks."
@@ -189,7 +189,7 @@ def test_patch_report(self, get_conn_mock):
profileId=PROFILE_ID, reportId=REPORT_ID, body=update_mask
)
- self.assertEqual(return_value, result)
+ assert return_value == result
@mock.patch(
"airflow.providers.google.marketing_platform.hooks."
@@ -207,7 +207,7 @@ def test_run_report(self, get_conn_mock):
profileId=PROFILE_ID, reportId=REPORT_ID, synchronous=synchronous
)
- self.assertEqual(return_value, result)
+ assert return_value == result
@mock.patch(
"airflow.providers.google.marketing_platform.hooks."
@@ -225,7 +225,7 @@ def test_update_report(self, get_conn_mock):
profileId=PROFILE_ID, reportId=REPORT_ID
)
- self.assertEqual(return_value, result)
+ assert return_value == result
@mock.patch(
"airflow.providers.google.marketing_platform."
@@ -264,7 +264,7 @@ def test_conversion_batch_insert(self, batch_request_mock, get_conn_mock):
profileId=PROFILE_ID, body=batch_request_mock.return_value
)
- self.assertEqual(return_value, result)
+ assert return_value == result
@mock.patch(
"airflow.providers.google.marketing_platform.hooks."
@@ -303,4 +303,4 @@ def test_conversions_batch_update(self, batch_request_mock, get_conn_mock):
profileId=PROFILE_ID, body=batch_request_mock.return_value
)
- self.assertEqual(return_value, result)
+ assert return_value == result
diff --git a/tests/providers/google/marketing_platform/hooks/test_display_video.py b/tests/providers/google/marketing_platform/hooks/test_display_video.py
index 2a14ac55941d8..34b73990e48e6 100644
--- a/tests/providers/google/marketing_platform/hooks/test_display_video.py
+++ b/tests/providers/google/marketing_platform/hooks/test_display_video.py
@@ -45,7 +45,7 @@ def test_gen_conn(self, mock_build, mock_authorize):
http=mock_authorize.return_value,
cache_discovery=False,
)
- self.assertEqual(mock_build.return_value, result)
+ assert mock_build.return_value == result
@mock.patch(
"airflow.providers.google.marketing_platform.hooks."
@@ -60,7 +60,7 @@ def test_get_conn_to_display_video(self, mock_build, mock_authorize):
http=mock_authorize.return_value,
cache_discovery=False,
)
- self.assertEqual(mock_build.return_value, result)
+ assert mock_build.return_value == result
@mock.patch(
"airflow.providers.google.marketing_platform.hooks."
@@ -78,7 +78,7 @@ def test_create_query(self, get_conn_mock):
get_conn_mock.return_value.queries.return_value.createquery.assert_called_once_with(body=body)
- self.assertEqual(return_value, result)
+ assert return_value == result
@mock.patch(
"airflow.providers.google.marketing_platform.hooks."
@@ -112,7 +112,7 @@ def test_get_query(self, get_conn_mock):
get_conn_mock.return_value.queries.return_value.getquery.assert_called_once_with(queryId=query_id)
- self.assertEqual(return_value, result)
+ assert return_value == result
@mock.patch(
"airflow.providers.google.marketing_platform.hooks."
@@ -129,7 +129,7 @@ def test_list_queries(self, get_conn_mock):
get_conn_mock.return_value.queries.return_value.listqueries.assert_called_once_with()
- self.assertEqual(queries, result)
+ assert queries == result
@mock.patch(
"airflow.providers.google.marketing_platform.hooks."
@@ -196,7 +196,7 @@ def test_download_line_items_should_return_equal_values(self, get_conn_mock):
# fmt: on
result = self.hook.download_line_items(request_body)
- self.assertEqual(line_item, result)
+ assert line_item == result
@mock.patch(
"airflow.providers.google.marketing_platform.hooks."
@@ -239,7 +239,7 @@ def test_upload_line_items_should_return_equal_values(self, get_conn_mock):
# fmt: on
result = self.hook.upload_line_items(line_items)
- self.assertEqual(return_value, result)
+ assert return_value == result
@mock.patch(
"airflow.providers.google.marketing_platform.hooks."
@@ -302,7 +302,7 @@ def test_create_sdf_download_tasks_return_equal_values(self, get_conn_to_display
# fmt: on
result = self.hook.create_sdf_download_operation(body_request=body_request)
- self.assertEqual(response, result)
+ assert response == result
@mock.patch(
"airflow.providers.google.marketing_platform.hooks."
@@ -346,7 +346,7 @@ def get_sdf_download_tasks_return_equal_values(self, get_conn_to_display_video):
result = self.hook.get_sdf_download_operation(operation_name=operation_name)
- self.assertEqual(operation_name, result)
+ assert operation_name == result
@mock.patch(
"airflow.providers.google.marketing_platform.hooks."
diff --git a/tests/providers/google/marketing_platform/hooks/test_search_ads.py b/tests/providers/google/marketing_platform/hooks/test_search_ads.py
index a3ad3d32abc83..07d477ae86090 100644
--- a/tests/providers/google/marketing_platform/hooks/test_search_ads.py
+++ b/tests/providers/google/marketing_platform/hooks/test_search_ads.py
@@ -42,7 +42,7 @@ def test_gen_conn(self, mock_build, mock_authorize):
http=mock_authorize.return_value,
cache_discovery=False,
)
- self.assertEqual(mock_build.return_value, result)
+ assert mock_build.return_value == result
@mock.patch("airflow.providers.google.marketing_platform.hooks.search_ads.GoogleSearchAdsHook.get_conn")
def test_insert(self, get_conn_mock):
@@ -57,7 +57,7 @@ def test_insert(self, get_conn_mock):
get_conn_mock.return_value.reports.return_value.request.assert_called_once_with(body=report)
- self.assertEqual(return_value, result)
+ assert return_value == result
@mock.patch("airflow.providers.google.marketing_platform.hooks.search_ads.GoogleSearchAdsHook.get_conn")
def test_get(self, get_conn_mock):
@@ -70,7 +70,7 @@ def test_get(self, get_conn_mock):
get_conn_mock.return_value.reports.return_value.get.assert_called_once_with(reportId=report_id)
- self.assertEqual(return_value, result)
+ assert return_value == result
@mock.patch("airflow.providers.google.marketing_platform.hooks.search_ads.GoogleSearchAdsHook.get_conn")
def test_get_file(self, get_conn_mock):
@@ -88,4 +88,4 @@ def test_get_file(self, get_conn_mock):
reportFragment=report_fragment, reportId=report_id
)
- self.assertEqual(return_value, result)
+ assert return_value == result
diff --git a/tests/providers/google/marketing_platform/sensors/test_campaign_manager.py b/tests/providers/google/marketing_platform/sensors/test_campaign_manager.py
index 5ea787c73729e..1f1f9398ab198 100644
--- a/tests/providers/google/marketing_platform/sensors/test_campaign_manager.py
+++ b/tests/providers/google/marketing_platform/sensors/test_campaign_manager.py
@@ -54,4 +54,4 @@ def test_execute(self, mock_base_op, hook_mock):
hook_mock.return_value.get_report.assert_called_once_with(
profile_id=profile_id, report_id=report_id, file_id=file_id
)
- self.assertTrue(result)
+ assert result
diff --git a/tests/providers/google/suite/hooks/test_drive.py b/tests/providers/google/suite/hooks/test_drive.py
index f6142c4aacf03..6e50df125182f 100644
--- a/tests/providers/google/suite/hooks/test_drive.py
+++ b/tests/providers/google/suite/hooks/test_drive.py
@@ -101,7 +101,7 @@ def test_ensure_folders_exists_when_no_folder_exists(self, mock_get_conn):
any_order=True,
)
- self.assertEqual("ID_4", result_value)
+ assert "ID_4" == result_value
@mock.patch("airflow.providers.google.suite.hooks.drive.GoogleDriveHook.get_conn")
def test_ensure_folders_exists_when_some_folders_exists(self, mock_get_conn):
@@ -143,7 +143,7 @@ def test_ensure_folders_exists_when_some_folders_exists(self, mock_get_conn):
any_order=True,
)
- self.assertEqual("ID_4", result_value)
+ assert "ID_4" == result_value
@mock.patch("airflow.providers.google.suite.hooks.drive.GoogleDriveHook.get_conn")
def test_ensure_folders_exists_when_all_folders_exists(self, mock_get_conn):
@@ -157,7 +157,7 @@ def test_ensure_folders_exists_when_all_folders_exists(self, mock_get_conn):
result_value = self.gdrive_hook._ensure_folders_exists("AAA/BBB/CCC/DDD")
mock_get_conn.return_value.files.return_value.create.assert_not_called()
- self.assertEqual("ID_4", result_value)
+ assert "ID_4" == result_value
@mock.patch("airflow.providers.google.suite.hooks.drive.MediaFileUpload")
@mock.patch("airflow.providers.google.suite.hooks.drive.GoogleDriveHook.get_conn")
@@ -183,7 +183,7 @@ def test_upload_file_to_root_directory(
)
]
)
- self.assertEqual(return_value, "FILE_ID")
+ assert return_value == "FILE_ID"
@mock.patch("airflow.providers.google.suite.hooks.drive.MediaFileUpload")
@mock.patch("airflow.providers.google.suite.hooks.drive.GoogleDriveHook.get_conn")
@@ -212,4 +212,4 @@ def test_upload_file_to_subdirectory(
)
]
)
- self.assertEqual(return_value, "FILE_ID")
+ assert return_value == "FILE_ID"
diff --git a/tests/providers/google/suite/hooks/test_sheets.py b/tests/providers/google/suite/hooks/test_sheets.py
index 3cf5d0b8645e7..2647ff59656b4 100644
--- a/tests/providers/google/suite/hooks/test_sheets.py
+++ b/tests/providers/google/suite/hooks/test_sheets.py
@@ -23,6 +23,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.google.suite.hooks.sheets import GSheetsHook
from tests.providers.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id
@@ -58,7 +60,7 @@ def test_gsheets_client_creation(self, mock_build, mock_authorize):
mock_build.assert_called_once_with(
'sheets', 'v4', http=mock_authorize.return_value, cache_discovery=False
)
- self.assertEqual(mock_build.return_value, result)
+ assert mock_build.return_value == result
@mock.patch("airflow.providers.google.suite.hooks.sheets.GSheetsHook.get_conn")
def test_get_values(self, get_conn):
@@ -72,7 +74,7 @@ def test_get_values(self, get_conn):
value_render_option=VALUE_RENDER_OPTION,
date_time_render_option=DATE_TIME_RENDER_OPTION,
)
- self.assertIs(result, VALUES)
+ assert result is VALUES
execute_method.assert_called_once_with(num_retries=NUM_RETRIES)
get_method.assert_called_once_with(
spreadsheetId=SPREADSHEET_ID,
@@ -94,7 +96,7 @@ def test_batch_get_values(self, get_conn):
value_render_option=VALUE_RENDER_OPTION,
date_time_render_option=DATE_TIME_RENDER_OPTION,
)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
execute_method.assert_called_once_with(num_retries=NUM_RETRIES)
batch_get_method.assert_called_once_with(
spreadsheetId=SPREADSHEET_ID,
@@ -120,7 +122,7 @@ def test_update_values(self, get_conn):
date_time_render_option=DATE_TIME_RENDER_OPTION,
)
body = {"range": RANGE_, "majorDimension": MAJOR_DIMENSION, "values": VALUES}
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
execute_method.assert_called_once_with(num_retries=NUM_RETRIES)
update_method.assert_called_once_with(
spreadsheetId=SPREADSHEET_ID,
@@ -158,7 +160,7 @@ def test_batch_update_values(self, get_conn):
"responseValueRenderOption": VALUE_RENDER_OPTION,
"responseDateTimeRenderOption": DATE_TIME_RENDER_OPTION,
}
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
execute_method.assert_called_once_with(num_retries=NUM_RETRIES)
batch_update_method.assert_called_once_with(spreadsheetId=SPREADSHEET_ID, body=body)
@@ -167,7 +169,7 @@ def test_batch_update_values_with_bad_data(self, get_conn):
batch_update_method = get_conn.return_value.spreadsheets.return_value.values.return_value.batchUpdate
execute_method = batch_update_method.return_value.execute
execute_method.return_value = API_RESPONSE
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
self.hook.batch_update_values(
spreadsheet_id=SPREADSHEET_ID,
ranges=['test!A1:B2', 'test!C1:C2'],
@@ -180,8 +182,8 @@ def test_batch_update_values_with_bad_data(self, get_conn):
)
batch_update_method.assert_not_called()
execute_method.assert_not_called()
- err = cm.exception
- self.assertIn("must be of equal length.", str(err))
+ err = ctx.value
+ assert "must be of equal length." in str(err)
@mock.patch("airflow.providers.google.suite.hooks.sheets.GSheetsHook.get_conn")
def test_append_values(self, get_conn):
@@ -200,7 +202,7 @@ def test_append_values(self, get_conn):
date_time_render_option=DATE_TIME_RENDER_OPTION,
)
body = {"range": RANGE_, "majorDimension": MAJOR_DIMENSION, "values": VALUES}
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
execute_method.assert_called_once_with(num_retries=NUM_RETRIES)
append_method.assert_called_once_with(
spreadsheetId=SPREADSHEET_ID,
@@ -220,7 +222,7 @@ def test_clear_values(self, get_conn):
execute_method.return_value = API_RESPONSE
result = self.hook.clear(spreadsheet_id=SPREADSHEET_ID, range_=RANGE_)
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
execute_method.assert_called_once_with(num_retries=NUM_RETRIES)
clear_method.assert_called_once_with(spreadsheetId=SPREADSHEET_ID, range=RANGE_)
@@ -231,7 +233,7 @@ def test_batch_clear_values(self, get_conn):
execute_method.return_value = API_RESPONSE
result = self.hook.batch_clear(spreadsheet_id=SPREADSHEET_ID, ranges=RANGES)
body = {"ranges": RANGES}
- self.assertIs(result, API_RESPONSE)
+ assert result is API_RESPONSE
execute_method.assert_called_once_with(num_retries=NUM_RETRIES)
batch_clear_method.assert_called_once_with(spreadsheetId=SPREADSHEET_ID, body=body)
diff --git a/tests/providers/google/suite/transfers/test_gcs_to_gdrive.py b/tests/providers/google/suite/transfers/test_gcs_to_gdrive.py
index 9344f30ffa2e8..096bc52e3c64e 100644
--- a/tests/providers/google/suite/transfers/test_gcs_to_gdrive.py
+++ b/tests/providers/google/suite/transfers/test_gcs_to_gdrive.py
@@ -18,6 +18,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.google.suite.transfers.gcs_to_gdrive import GCSToGoogleDriveOperator
@@ -170,5 +172,5 @@ def test_should_raise_exception_on_multiple_wildcard(
task = GCSToGoogleDriveOperator(
task_id="move_files", source_bucket="data", source_object="sales/*/*.avro", move_object=True
)
- with self.assertRaisesRegex(AirflowException, "Only one wildcard"):
+ with pytest.raises(AirflowException, match="Only one wildcard"):
task.execute(mock.MagicMock())
diff --git a/tests/providers/grpc/hooks/test_grpc.py b/tests/providers/grpc/hooks/test_grpc.py
index cdd8019279dca..ec1e7d6658e06 100644
--- a/tests/providers/grpc/hooks/test_grpc.py
+++ b/tests/providers/grpc/hooks/test_grpc.py
@@ -19,6 +19,8 @@
from io import StringIO
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowConfigException
from airflow.models import Connection
from airflow.providers.grpc.hooks.grpc import GrpcHook
@@ -78,7 +80,7 @@ def test_no_auth_connection(self, mock_get_connection, mock_insecure_channel):
expected_url = "test:8080"
mock_insecure_channel.assert_called_once_with(expected_url)
- self.assertEqual(channel, mocked_channel)
+ assert channel == mocked_channel
@mock.patch('grpc.insecure_channel')
@mock.patch('airflow.hooks.base.BaseHook.get_connection')
@@ -93,7 +95,7 @@ def test_connection_with_port(self, mock_get_connection, mock_insecure_channel):
expected_url = "test.com:1234"
mock_insecure_channel.assert_called_once_with(expected_url)
- self.assertEqual(channel, mocked_channel)
+ assert channel == mocked_channel
@mock.patch('airflow.providers.grpc.hooks.grpc.open')
@mock.patch('airflow.hooks.base.BaseHook.get_connection')
@@ -117,7 +119,7 @@ def test_connection_with_ssl(
mock_open.assert_called_once_with("pem", "rb")
mock_channel_credentials.assert_called_once_with('credential')
mock_secure_channel.assert_called_once_with(expected_url, mock_credential_object)
- self.assertEqual(channel, mocked_channel)
+ assert channel == mocked_channel
@mock.patch('airflow.providers.grpc.hooks.grpc.open')
@mock.patch('airflow.hooks.base.BaseHook.get_connection')
@@ -141,7 +143,7 @@ def test_connection_with_tls(
mock_open.assert_called_once_with("pem", "rb")
mock_channel_credentials.assert_called_once_with('credential')
mock_secure_channel.assert_called_once_with(expected_url, mock_credential_object)
- self.assertEqual(channel, mocked_channel)
+ assert channel == mocked_channel
@mock.patch('airflow.hooks.base.BaseHook.get_connection')
@mock.patch('google.auth.jwt.OnDemandCredentials.from_signing_credentials')
@@ -164,7 +166,7 @@ def test_connection_with_jwt(
mock_google_cred.assert_called_once_with(mock_credential_object)
mock_secure_channel.assert_called_once_with(mock_credential_object, None, expected_url)
- self.assertEqual(channel, mocked_channel)
+ assert channel == mocked_channel
@mock.patch('airflow.hooks.base.BaseHook.get_connection')
@mock.patch('google.auth.transport.requests.Request')
@@ -187,7 +189,7 @@ def test_connection_with_google_oauth(
mock_google_default_auth.assert_called_once_with(scopes=["grpc", "gcs"])
mock_secure_channel.assert_called_once_with(mock_credential_object, "request", expected_url)
- self.assertEqual(channel, mocked_channel)
+ assert channel == mocked_channel
@mock.patch('airflow.hooks.base.BaseHook.get_connection')
def test_custom_connection(self, mock_get_connection):
@@ -198,7 +200,7 @@ def test_custom_connection(self, mock_get_connection):
channel = hook.get_conn()
- self.assertEqual(channel, mocked_channel)
+ assert channel == mocked_channel
@mock.patch('airflow.hooks.base.BaseHook.get_connection')
def test_custom_connection_with_no_connection_func(self, mock_get_connection):
@@ -206,7 +208,7 @@ def test_custom_connection_with_no_connection_func(self, mock_get_connection):
mock_get_connection.return_value = conn
hook = GrpcHook("grpc_default")
- with self.assertRaises(AirflowConfigException):
+ with pytest.raises(AirflowConfigException):
hook.get_conn()
@mock.patch('airflow.hooks.base.BaseHook.get_connection')
@@ -215,7 +217,7 @@ def test_connection_type_not_supported(self, mock_get_connection):
mock_get_connection.return_value = conn
hook = GrpcHook("grpc_default")
- with self.assertRaises(AirflowConfigException):
+ with pytest.raises(AirflowConfigException):
hook.get_conn()
@mock.patch('grpc.intercept_channel')
@@ -233,7 +235,7 @@ def test_connection_with_interceptors(
channel = hook.get_conn()
- self.assertEqual(channel, mocked_channel)
+ assert channel == mocked_channel
mock_intercept_channel.assert_called_once_with(mocked_channel, "test1")
@mock.patch('airflow.hooks.base.BaseHook.get_connection')
@@ -249,7 +251,7 @@ def test_simple_run(self, mock_get_conn, mock_get_connection):
response = hook.run(StubClass, "single_call", data={'data': 'hello'})
- self.assertEqual(next(response), "hello")
+ assert next(response) == "hello"
@mock.patch('airflow.hooks.base.BaseHook.get_connection')
@mock.patch('airflow.providers.grpc.hooks.grpc.GrpcHook.get_conn')
@@ -264,4 +266,4 @@ def test_stream_run(self, mock_get_conn, mock_get_connection):
response = hook.run(StubClass, "stream_call", data={'data': ['hello!', "hi"]})
- self.assertEqual(next(response), ["streaming", "call"])
+ assert next(response) == ["streaming", "call"]
diff --git a/tests/providers/grpc/operators/test_grpc.py b/tests/providers/grpc/operators/test_grpc.py
index 1bbed8b789c0c..3faea91fa8db6 100644
--- a/tests/providers/grpc/operators/test_grpc.py
+++ b/tests/providers/grpc/operators/test_grpc.py
@@ -94,7 +94,7 @@ def test_execute_with_callback(self, mock_hook):
operator.execute({})
mock_hook.assert_called_once_with("grpc_default", interceptors=None, custom_connection_func=None)
mocked_hook.run.assert_called_once_with(StubClass, "stream_call", data={}, streaming=False)
- self.assertTrue(("'value1'", "'value2'") not in mock_info.call_args_list)
+ assert ("'value1'", "'value2'") not in mock_info.call_args_list
mock_info.assert_any_call("Calling gRPC service")
callback.assert_any_call("value1", {})
callback.assert_any_call("value2", {})
diff --git a/tests/providers/hashicorp/_internal_client/test_vault_client.py b/tests/providers/hashicorp/_internal_client/test_vault_client.py
index 462cd2eda9f94..8635ca91216b6 100644
--- a/tests/providers/hashicorp/_internal_client/test_vault_client.py
+++ b/tests/providers/hashicorp/_internal_client/test_vault_client.py
@@ -20,6 +20,7 @@
from unittest.case import TestCase
from unittest.mock import mock_open, patch
+import pytest
from hvac.exceptions import InvalidPath, VaultError
from airflow.providers.hashicorp._internal_client.vault_client import _VaultClient # noqa
@@ -30,7 +31,7 @@ class TestVaultClient(TestCase):
def test_version_wrong(self, mock_hvac):
mock_client = mock.MagicMock()
mock_hvac.Client.return_value = mock_client
- with self.assertRaisesRegex(VaultError, 'The version is not supported: 4'):
+ with pytest.raises(VaultError, match='The version is not supported: 4'):
_VaultClient(auth_type="approle", kv_engine_version=4)
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -38,7 +39,7 @@ def test_custom_mount_point(self, mock_hvac):
mock_client = mock.MagicMock()
mock_hvac.Client.return_value = mock_client
vault_client = _VaultClient(auth_type="userpass", mount_point="custom")
- self.assertEqual("custom", vault_client.mount_point)
+ assert "custom" == vault_client.mount_point
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_version_one_init(self, mock_hvac):
@@ -46,7 +47,7 @@ def test_version_one_init(self, mock_hvac):
mock_hvac.Client.return_value = mock_client
vault_client = _VaultClient(auth_type="userpass", kv_engine_version=1)
- self.assertEqual(1, vault_client.kv_engine_version)
+ assert 1 == vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_approle(self, mock_hvac):
@@ -59,7 +60,7 @@ def test_approle(self, mock_hvac):
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
client.auth_approle.assert_called_with(role_id="role", secret_id="pass")
client.is_authenticated.assert_called_with()
- self.assertEqual(2, vault_client.kv_engine_version)
+ assert 2 == vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_approle_different_auth_mount_point(self, mock_hvac):
@@ -76,13 +77,13 @@ def test_approle_different_auth_mount_point(self, mock_hvac):
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
client.auth_approle.assert_called_with(role_id="role", secret_id="pass", mount_point="other")
client.is_authenticated.assert_called_with()
- self.assertEqual(2, vault_client.kv_engine_version)
+ assert 2 == vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_approle_missing_role(self, mock_hvac):
mock_client = mock.MagicMock()
mock_hvac.Client.return_value = mock_client
- with self.assertRaisesRegex(VaultError, "requires 'role_id'"):
+ with pytest.raises(VaultError, match="requires 'role_id'"):
_VaultClient(auth_type="approle", url="http://localhost:8180", secret_id="pass")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -100,7 +101,7 @@ def test_aws_iam(self, mock_hvac):
role="role",
)
client.is_authenticated.assert_called_with()
- self.assertEqual(2, vault_client.kv_engine_version)
+ assert 2 == vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_aws_iam_different_auth_mount_point(self, mock_hvac):
@@ -120,7 +121,7 @@ def test_aws_iam_different_auth_mount_point(self, mock_hvac):
access_key='user', secret_key='pass', role="role", mount_point='other'
)
client.is_authenticated.assert_called_with()
- self.assertEqual(2, vault_client.kv_engine_version)
+ assert 2 == vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_azure(self, mock_hvac):
@@ -143,7 +144,7 @@ def test_azure(self, mock_hvac):
client_secret="pass",
)
client.is_authenticated.assert_called_with()
- self.assertEqual(2, vault_client.kv_engine_version)
+ assert 2 == vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_azure_different_auth_mount_point(self, mock_hvac):
@@ -168,13 +169,13 @@ def test_azure_different_auth_mount_point(self, mock_hvac):
mount_point="other",
)
client.is_authenticated.assert_called_with()
- self.assertEqual(2, vault_client.kv_engine_version)
+ assert 2 == vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_azure_missing_resource(self, mock_hvac):
mock_client = mock.MagicMock()
mock_hvac.Client.return_value = mock_client
- with self.assertRaisesRegex(VaultError, "requires 'azure_resource'"):
+ with pytest.raises(VaultError, match="requires 'azure_resource'"):
_VaultClient(
auth_type="azure",
azure_tenant_id="tenant_id",
@@ -187,7 +188,7 @@ def test_azure_missing_resource(self, mock_hvac):
def test_azure_missing_tenant_id(self, mock_hvac):
mock_client = mock.MagicMock()
mock_hvac.Client.return_value = mock_client
- with self.assertRaisesRegex(VaultError, "requires 'azure_tenant_id'"):
+ with pytest.raises(VaultError, match="requires 'azure_tenant_id'"):
_VaultClient(
auth_type="azure",
azure_resource='resource',
@@ -218,7 +219,7 @@ def test_gcp(self, mock_hvac, mock_get_credentials, mock_get_scopes):
credentials="credentials",
)
client.is_authenticated.assert_called_with()
- self.assertEqual(2, vault_client.kv_engine_version)
+ assert 2 == vault_client.kv_engine_version
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes")
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id")
@@ -244,7 +245,7 @@ def test_gcp_different_auth_mount_point(self, mock_hvac, mock_get_credentials, m
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
client.auth.gcp.configure.assert_called_with(credentials="credentials", mount_point="other")
client.is_authenticated.assert_called_with()
- self.assertEqual(2, vault_client.kv_engine_version)
+ assert 2 == vault_client.kv_engine_version
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes")
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id")
@@ -271,7 +272,7 @@ def test_gcp_dict(self, mock_hvac, mock_get_credentials, mock_get_scopes):
credentials="credentials",
)
client.is_authenticated.assert_called_with()
- self.assertEqual(2, vault_client.kv_engine_version)
+ assert 2 == vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_github(self, mock_hvac):
@@ -284,7 +285,7 @@ def test_github(self, mock_hvac):
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
client.auth.github.login.assert_called_with(token="s.7AU0I51yv1Q1lxOIg1F3ZRAS")
client.is_authenticated.assert_called_with()
- self.assertEqual(2, vault_client.kv_engine_version)
+ assert 2 == vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_github_different_auth_mount_point(self, mock_hvac):
@@ -300,13 +301,13 @@ def test_github_different_auth_mount_point(self, mock_hvac):
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
client.auth.github.login.assert_called_with(token="s.7AU0I51yv1Q1lxOIg1F3ZRAS", mount_point="other")
client.is_authenticated.assert_called_with()
- self.assertEqual(2, vault_client.kv_engine_version)
+ assert 2 == vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_github_missing_token(self, mock_hvac):
mock_client = mock.MagicMock()
mock_hvac.Client.return_value = mock_client
- with self.assertRaisesRegex(VaultError, "'github' authentication type requires 'token'"):
+ with pytest.raises(VaultError, match="'github' authentication type requires 'token'"):
_VaultClient(auth_type="github", url="http://localhost:8180")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -322,7 +323,7 @@ def test_kubernetes_default_path(self, mock_hvac):
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
client.auth_kubernetes.assert_called_with(role="kube_role", jwt="data")
client.is_authenticated.assert_called_with()
- self.assertEqual(2, vault_client.kv_engine_version)
+ assert 2 == vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_kubernetes(self, mock_hvac):
@@ -340,7 +341,7 @@ def test_kubernetes(self, mock_hvac):
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
client.auth_kubernetes.assert_called_with(role="kube_role", jwt="data")
client.is_authenticated.assert_called_with()
- self.assertEqual(2, vault_client.kv_engine_version)
+ assert 2 == vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_kubernetes_different_auth_mount_point(self, mock_hvac):
@@ -359,20 +360,20 @@ def test_kubernetes_different_auth_mount_point(self, mock_hvac):
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
client.auth_kubernetes.assert_called_with(role="kube_role", jwt="data", mount_point="other")
client.is_authenticated.assert_called_with()
- self.assertEqual(2, vault_client.kv_engine_version)
+ assert 2 == vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_kubernetes_missing_role(self, mock_hvac):
mock_client = mock.MagicMock()
mock_hvac.Client.return_value = mock_client
- with self.assertRaisesRegex(VaultError, "requires 'kubernetes_role'"):
+ with pytest.raises(VaultError, match="requires 'kubernetes_role'"):
_VaultClient(auth_type="kubernetes", kubernetes_jwt_path="path", url="http://localhost:8180")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_kubernetes_kubernetes_jwt_path_none(self, mock_hvac):
mock_client = mock.MagicMock()
mock_hvac.Client.return_value = mock_client
- with self.assertRaisesRegex(VaultError, "requires 'kubernetes_jwt_path'"):
+ with pytest.raises(VaultError, match="requires 'kubernetes_jwt_path'"):
_VaultClient(
auth_type="kubernetes",
kubernetes_role='kube_role',
@@ -391,7 +392,7 @@ def test_ldap(self, mock_hvac):
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
client.auth.ldap.login.assert_called_with(username="user", password="pass")
client.is_authenticated.assert_called_with()
- self.assertEqual(2, vault_client.kv_engine_version)
+ assert 2 == vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_ldap_different_auth_mount_point(self, mock_hvac):
@@ -408,20 +409,20 @@ def test_ldap_different_auth_mount_point(self, mock_hvac):
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
client.auth.ldap.login.assert_called_with(username="user", password="pass", mount_point="other")
client.is_authenticated.assert_called_with()
- self.assertEqual(2, vault_client.kv_engine_version)
+ assert 2 == vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_radius_missing_host(self, mock_hvac):
mock_client = mock.MagicMock()
mock_hvac.Client.return_value = mock_client
- with self.assertRaisesRegex(VaultError, "radius_host"):
+ with pytest.raises(VaultError, match="radius_host"):
_VaultClient(auth_type="radius", radius_secret="pass", url="http://localhost:8180")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_radius_missing_secret(self, mock_hvac):
mock_client = mock.MagicMock()
mock_hvac.Client.return_value = mock_client
- with self.assertRaisesRegex(VaultError, "radius_secret"):
+ with pytest.raises(VaultError, match="radius_secret"):
_VaultClient(auth_type="radius", radius_host="radhost", url="http://localhost:8180")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -435,7 +436,7 @@ def test_radius(self, mock_hvac):
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
client.auth.radius.configure.assert_called_with(host="radhost", secret="pass", port=None)
client.is_authenticated.assert_called_with()
- self.assertEqual(2, vault_client.kv_engine_version)
+ assert 2 == vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_radius_different_auth_mount_point(self, mock_hvac):
@@ -454,7 +455,7 @@ def test_radius_different_auth_mount_point(self, mock_hvac):
host="radhost", secret="pass", port=None, mount_point="other"
)
client.is_authenticated.assert_called_with()
- self.assertEqual(2, vault_client.kv_engine_version)
+ assert 2 == vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_radius_port(self, mock_hvac):
@@ -471,13 +472,13 @@ def test_radius_port(self, mock_hvac):
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
client.auth.radius.configure.assert_called_with(host="radhost", secret="pass", port=8110)
client.is_authenticated.assert_called_with()
- self.assertEqual(2, vault_client.kv_engine_version)
+ assert 2 == vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_token_missing_token(self, mock_hvac):
mock_client = mock.MagicMock()
mock_hvac.Client.return_value = mock_client
- with self.assertRaisesRegex(VaultError, "'token' authentication type requires 'token'"):
+ with pytest.raises(VaultError, match="'token' authentication type requires 'token'"):
_VaultClient(auth_type="token", url="http://localhost:8180")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -490,9 +491,9 @@ def test_token(self, mock_hvac):
client = vault_client.client
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
client.is_authenticated.assert_called_with()
- self.assertEqual("s.7AU0I51yv1Q1lxOIg1F3ZRAS", client.token)
- self.assertEqual(2, vault_client.kv_engine_version)
- self.assertEqual("secret", vault_client.mount_point)
+ assert "s.7AU0I51yv1Q1lxOIg1F3ZRAS" == client.token
+ assert 2 == vault_client.kv_engine_version
+ assert "secret" == vault_client.mount_point
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_token_path(self, mock_hvac):
@@ -506,9 +507,9 @@ def test_token_path(self, mock_hvac):
client = vault_client.client
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
client.is_authenticated.assert_called_with()
- self.assertEqual("s.7AU0I51yv1Q1lxOIg1F3ZRAS", client.token)
- self.assertEqual(2, vault_client.kv_engine_version)
- self.assertEqual("secret", vault_client.mount_point)
+ assert "s.7AU0I51yv1Q1lxOIg1F3ZRAS" == client.token
+ assert 2 == vault_client.kv_engine_version
+ assert "secret" == vault_client.mount_point
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_default_auth_type(self, mock_hvac):
@@ -518,10 +519,10 @@ def test_default_auth_type(self, mock_hvac):
client = vault_client.client
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
client.is_authenticated.assert_called_with()
- self.assertEqual("s.7AU0I51yv1Q1lxOIg1F3ZRAS", client.token)
- self.assertEqual("token", vault_client.auth_type)
- self.assertEqual(2, vault_client.kv_engine_version)
- self.assertEqual("secret", vault_client.mount_point)
+ assert "s.7AU0I51yv1Q1lxOIg1F3ZRAS" == client.token
+ assert "token" == vault_client.auth_type
+ assert 2 == vault_client.kv_engine_version
+ assert "secret" == vault_client.mount_point
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_userpass(self, mock_hvac):
@@ -534,7 +535,7 @@ def test_userpass(self, mock_hvac):
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
client.auth_userpass.assert_called_with(username="user", password="pass")
client.is_authenticated.assert_called_with()
- self.assertEqual(2, vault_client.kv_engine_version)
+ assert 2 == vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_userpass_different_auth_mount_point(self, mock_hvac):
@@ -551,7 +552,7 @@ def test_userpass_different_auth_mount_point(self, mock_hvac):
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
client.auth_userpass.assert_called_with(username="user", password="pass", mount_point="other")
client.is_authenticated.assert_called_with()
- self.assertEqual(2, vault_client.kv_engine_version)
+ assert 2 == vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_get_non_existing_key_v2(self, mock_hvac):
@@ -563,7 +564,7 @@ def test_get_non_existing_key_v2(self, mock_hvac):
auth_type="token", token="s.7AU0I51yv1Q1lxOIg1F3ZRAS", url="http://localhost:8180"
)
secret = vault_client.get_secret(secret_path="missing")
- self.assertIsNone(secret)
+ assert secret is None
mock_client.secrets.kv.v2.read_secret_version.assert_called_once_with(
mount_point='secret', path='missing', version=None
)
@@ -582,8 +583,8 @@ def test_get_non_existing_key_v2_different_auth(self, mock_hvac):
url="http://localhost:8180",
)
secret = vault_client.get_secret(secret_path="missing")
- self.assertIsNone(secret)
- self.assertEqual("secret", vault_client.mount_point)
+ assert secret is None
+ assert "secret" == vault_client.mount_point
mock_client.secrets.kv.v2.read_secret_version.assert_called_once_with(
mount_point='secret', path='missing', version=None
)
@@ -603,7 +604,7 @@ def test_get_non_existing_key_v1(self, mock_hvac):
url="http://localhost:8180",
)
secret = vault_client.get_secret(secret_path="missing")
- self.assertIsNone(secret)
+ assert secret is None
mock_client.secrets.kv.v1.read_secret.assert_called_once_with(mount_point='secret', path='missing')
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -638,7 +639,7 @@ def test_get_existing_key_v2(self, mock_hvac):
url="http://localhost:8180",
)
secret = vault_client.get_secret(secret_path="missing")
- self.assertEqual({'secret_key': 'secret_value'}, secret)
+ assert {'secret_key': 'secret_value'} == secret
mock_client.secrets.kv.v2.read_secret_version.assert_called_once_with(
mount_point='secret', path='missing', version=None
)
@@ -675,7 +676,7 @@ def test_get_existing_key_v2_version(self, mock_hvac):
url="http://localhost:8180",
)
secret = vault_client.get_secret(secret_path="missing", secret_version=1)
- self.assertEqual({'secret_key': 'secret_value'}, secret)
+ assert {'secret_key': 'secret_value'} == secret
mock_client.secrets.kv.v2.read_secret_version.assert_called_once_with(
mount_point='secret', path='missing', version=1
)
@@ -705,7 +706,7 @@ def test_get_existing_key_v1(self, mock_hvac):
url="http://localhost:8180",
)
secret = vault_client.get_secret(secret_path="missing")
- self.assertEqual({'value': 'world'}, secret)
+ assert {'value': 'world'} == secret
mock_client.secrets.kv.v1.read_secret.assert_called_once_with(mount_point='secret', path='missing')
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -734,7 +735,7 @@ def test_get_existing_key_v1_different_auth_mount_point(self, mock_hvac):
url="http://localhost:8180",
)
secret = vault_client.get_secret(secret_path="missing")
- self.assertEqual({'value': 'world'}, secret)
+ assert {'value': 'world'} == secret
mock_client.secrets.kv.v1.read_secret.assert_called_once_with(mount_point='secret', path='missing')
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -747,7 +748,7 @@ def test_get_existing_key_v1_version(self, mock_hvac):
url="http://localhost:8180",
kv_engine_version=1,
)
- with self.assertRaisesRegex(VaultError, "Secret version"):
+ with pytest.raises(VaultError, match="Secret version"):
vault_client.get_secret(secret_path="missing", secret_version=1)
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -778,29 +779,26 @@ def test_get_secret_metadata_v2(self, mock_hvac):
auth_type="token", token="s.7AU0I51yv1Q1lxOIg1F3ZRAS", url="http://localhost:8180"
)
metadata = vault_client.get_secret_metadata(secret_path="missing")
- self.assertEqual(
- {
- 'request_id': '94011e25-f8dc-ec29-221b-1f9c1d9ad2ae',
- 'lease_id': '',
- 'renewable': False,
- 'lease_duration': 0,
- 'metadata': [
- {
- 'created_time': '2020-03-16T21:01:43.331126Z',
- 'deletion_time': '',
- 'destroyed': False,
- 'version': 1,
- },
- {
- 'created_time': '2020-03-16T21:01:43.331126Z',
- 'deletion_time': '',
- 'destroyed': False,
- 'version': 2,
- },
- ],
- },
- metadata,
- )
+ assert {
+ 'request_id': '94011e25-f8dc-ec29-221b-1f9c1d9ad2ae',
+ 'lease_id': '',
+ 'renewable': False,
+ 'lease_duration': 0,
+ 'metadata': [
+ {
+ 'created_time': '2020-03-16T21:01:43.331126Z',
+ 'deletion_time': '',
+ 'destroyed': False,
+ 'version': 1,
+ },
+ {
+ 'created_time': '2020-03-16T21:01:43.331126Z',
+ 'deletion_time': '',
+ 'destroyed': False,
+ 'version': 2,
+ },
+ ],
+ } == metadata
mock_client.secrets.kv.v2.read_secret_metadata.assert_called_once_with(
mount_point='secret', path='missing'
)
@@ -818,9 +816,7 @@ def test_get_secret_metadata_v1(self, mock_hvac):
kv_engine_version=1,
url="http://localhost:8180",
)
- with self.assertRaisesRegex(
- VaultError, "Metadata might only be used with version 2 of the KV engine."
- ):
+ with pytest.raises(VaultError, match="Metadata might only be used with version 2 of the KV engine."):
vault_client.get_secret_metadata(secret_path="missing")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -854,27 +850,24 @@ def test_get_secret_including_metadata_v2(self, mock_hvac):
url="http://localhost:8180",
)
metadata = vault_client.get_secret_including_metadata(secret_path="missing")
- self.assertEqual(
- {
- 'request_id': '94011e25-f8dc-ec29-221b-1f9c1d9ad2ae',
- 'lease_id': '',
- 'renewable': False,
- 'lease_duration': 0,
- 'data': {
- 'data': {'secret_key': 'secret_value'},
- 'metadata': {
- 'created_time': '2020-03-16T21:01:43.331126Z',
- 'deletion_time': '',
- 'destroyed': False,
- 'version': 1,
- },
+ assert {
+ 'request_id': '94011e25-f8dc-ec29-221b-1f9c1d9ad2ae',
+ 'lease_id': '',
+ 'renewable': False,
+ 'lease_duration': 0,
+ 'data': {
+ 'data': {'secret_key': 'secret_value'},
+ 'metadata': {
+ 'created_time': '2020-03-16T21:01:43.331126Z',
+ 'deletion_time': '',
+ 'destroyed': False,
+ 'version': 1,
},
- 'wrap_info': None,
- 'warnings': None,
- 'auth': None,
},
- metadata,
- )
+ 'wrap_info': None,
+ 'warnings': None,
+ 'auth': None,
+ } == metadata
mock_client.secrets.kv.v2.read_secret_version.assert_called_once_with(
mount_point='secret', path='missing', version=None
)
@@ -892,9 +885,7 @@ def test_get_secret_including_metadata_v1(self, mock_hvac):
kv_engine_version=1,
url="http://localhost:8180",
)
- with self.assertRaisesRegex(
- VaultError, "Metadata might only be used with version 2 of the KV engine."
- ):
+ with pytest.raises(VaultError, match="Metadata might only be used with version 2 of the KV engine."):
vault_client.get_secret_including_metadata(secret_path="missing")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -926,7 +917,7 @@ def test_create_or_update_secret_v2_method(self, mock_hvac):
radius_secret="pass",
url="http://localhost:8180",
)
- with self.assertRaisesRegex(VaultError, "The method parameter is only valid for version 1"):
+ with pytest.raises(VaultError, match="The method parameter is only valid for version 1"):
vault_client.create_or_update_secret(secret_path="path", secret={'key': 'value'}, method="post")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -977,7 +968,7 @@ def test_create_or_update_secret_v1_cas(self, mock_hvac):
kv_engine_version=1,
url="http://localhost:8180",
)
- with self.assertRaisesRegex(VaultError, "The cas parameter is only valid for version 2"):
+ with pytest.raises(VaultError, match="The cas parameter is only valid for version 2"):
vault_client.create_or_update_secret(secret_path="path", secret={'key': 'value'}, cas=10)
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
diff --git a/tests/providers/hashicorp/hooks/test_vault.py b/tests/providers/hashicorp/hooks/test_vault.py
index f4bf4cc1b33e3..df92c30d67677 100644
--- a/tests/providers/hashicorp/hooks/test_vault.py
+++ b/tests/providers/hashicorp/hooks/test_vault.py
@@ -20,6 +20,7 @@
from unittest.case import TestCase
from unittest.mock import PropertyMock, mock_open, patch
+import pytest
from hvac.exceptions import VaultError
from parameterized import parameterized
@@ -54,7 +55,7 @@ def test_version_not_int(self, mock_hvac, mock_get_connection):
kwargs = {
"vault_conn_id": "vault_conn_id",
}
- with self.assertRaisesRegex(VaultError, 'The version is not an int: text'):
+ with pytest.raises(VaultError, match='The version is not an int: text'):
VaultHook(**kwargs)
@parameterized.expand(
@@ -78,7 +79,7 @@ def test_version(self, version, expected_version, mock_hvac, mock_get_connection
"vault_conn_id": "vault_conn_id",
}
test_hook = VaultHook(**kwargs)
- self.assertEqual(expected_version, test_hook.vault_client.kv_engine_version)
+ assert expected_version == test_hook.vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -97,7 +98,7 @@ def test_custom_mount_point_dejson(self, mock_hvac, mock_get_connection):
"vault_conn_id": "vault_conn_id",
}
test_hook = VaultHook(**kwargs)
- self.assertEqual("custom", test_hook.vault_client.mount_point)
+ assert "custom" == test_hook.vault_client.mount_point
@mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -114,8 +115,8 @@ def test_custom_auth_mount_point_init_params(self, mock_hvac, mock_get_connectio
mock_connection.extra_dejson.get.side_effect = connection_dict.get
kwargs = {"vault_conn_id": "vault_conn_id", "auth_mount_point": "custom"}
test_hook = VaultHook(**kwargs)
- self.assertEqual("secret", test_hook.vault_client.mount_point)
- self.assertEqual("custom", test_hook.vault_client.auth_mount_point)
+ assert "secret" == test_hook.vault_client.mount_point
+ assert "custom" == test_hook.vault_client.auth_mount_point
@mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -132,8 +133,8 @@ def test_custom_auth_mount_point_dejson(self, mock_hvac, mock_get_connection):
"vault_conn_id": "vault_conn_id",
}
test_hook = VaultHook(**kwargs)
- self.assertEqual("secret", test_hook.vault_client.mount_point)
- self.assertEqual("custom", test_hook.vault_client.auth_mount_point)
+ assert "secret" == test_hook.vault_client.mount_point
+ assert "custom" == test_hook.vault_client.auth_mount_point
@mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -153,7 +154,7 @@ def test_version_one_dejson(self, mock_hvac, mock_get_connection):
"vault_conn_id": "vault_conn_id",
}
test_hook = VaultHook(**kwargs)
- self.assertEqual(1, test_hook.vault_client.kv_engine_version)
+ assert 1 == test_hook.vault_client.kv_engine_version
@parameterized.expand(
[
@@ -186,7 +187,7 @@ def test_protocol(self, protocol, expected_url, mock_hvac, mock_get_connection):
mock_hvac.Client.assert_called_with(url=expected_url)
test_client.auth_approle.assert_called_with(role_id="role", secret_id="pass")
test_client.is_authenticated.assert_called_with()
- self.assertEqual(2, test_hook.vault_client.kv_engine_version)
+ assert 2 == test_hook.vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -212,7 +213,7 @@ def test_approle_init_params(self, mock_hvac, mock_get_connection):
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
test_client.auth_approle.assert_called_with(role_id="role", secret_id="pass")
test_client.is_authenticated.assert_called_with()
- self.assertEqual(2, test_hook.vault_client.kv_engine_version)
+ assert 2 == test_hook.vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -238,7 +239,7 @@ def test_approle_dejson(self, mock_hvac, mock_get_connection):
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
test_client.auth_approle.assert_called_with(role_id="role", secret_id="pass")
test_client.is_authenticated.assert_called_with()
- self.assertEqual(2, test_hook.vault_client.kv_engine_version)
+ assert 2 == test_hook.vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -263,7 +264,7 @@ def test_aws_iam_init_params(self, mock_hvac, mock_get_connection):
role="role",
)
test_client.is_authenticated.assert_called_with()
- self.assertEqual(2, test_hook.vault_client.kv_engine_version)
+ assert 2 == test_hook.vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -319,7 +320,7 @@ def test_azure_init_params(self, mock_hvac, mock_get_connection):
client_secret="pass",
)
test_client.is_authenticated.assert_called_with()
- self.assertEqual(2, test_hook.vault_client.kv_engine_version)
+ assert 2 == test_hook.vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -351,7 +352,7 @@ def test_azure_dejson(self, mock_hvac, mock_get_connection):
client_secret="pass",
)
test_client.is_authenticated.assert_called_with()
- self.assertEqual(2, test_hook.vault_client.kv_engine_version)
+ assert 2 == test_hook.vault_client.kv_engine_version
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes")
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id")
@@ -387,7 +388,7 @@ def test_gcp_init_params(self, mock_hvac, mock_get_connection, mock_get_credenti
credentials="credentials",
)
test_client.is_authenticated.assert_called_with()
- self.assertEqual(2, test_hook.vault_client.kv_engine_version)
+ assert 2 == test_hook.vault_client.kv_engine_version
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes")
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id")
@@ -424,7 +425,7 @@ def test_gcp_dejson(self, mock_hvac, mock_get_connection, mock_get_credentials,
credentials="credentials",
)
test_client.is_authenticated.assert_called_with()
- self.assertEqual(2, test_hook.vault_client.kv_engine_version)
+ assert 2 == test_hook.vault_client.kv_engine_version
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes")
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id")
@@ -461,7 +462,7 @@ def test_gcp_dict_dejson(self, mock_hvac, mock_get_connection, mock_get_credenti
credentials="credentials",
)
test_client.is_authenticated.assert_called_with()
- self.assertEqual(2, test_hook.vault_client.kv_engine_version)
+ assert 2 == test_hook.vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -485,7 +486,7 @@ def test_github_init_params(self, mock_hvac, mock_get_connection):
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
test_client.auth.github.login.assert_called_with(token="pass")
test_client.is_authenticated.assert_called_with()
- self.assertEqual(2, test_hook.vault_client.kv_engine_version)
+ assert 2 == test_hook.vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -510,7 +511,7 @@ def test_github_dejson(self, mock_hvac, mock_get_connection):
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
test_client.auth.github.login.assert_called_with(token="pass")
test_client.is_authenticated.assert_called_with()
- self.assertEqual(2, test_hook.vault_client.kv_engine_version)
+ assert 2 == test_hook.vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -537,7 +538,7 @@ def test_kubernetes_default_path(self, mock_hvac, mock_get_connection):
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
test_client.auth_kubernetes.assert_called_with(role="kube_role", jwt="data")
test_client.is_authenticated.assert_called_with()
- self.assertEqual(2, test_hook.vault_client.kv_engine_version)
+ assert 2 == test_hook.vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -565,7 +566,7 @@ def test_kubernetes_init_params(self, mock_hvac, mock_get_connection):
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
test_client.auth_kubernetes.assert_called_with(role="kube_role", jwt="data")
test_client.is_authenticated.assert_called_with()
- self.assertEqual(2, test_hook.vault_client.kv_engine_version)
+ assert 2 == test_hook.vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -592,7 +593,7 @@ def test_kubernetes_dejson(self, mock_hvac, mock_get_connection):
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
test_client.auth_kubernetes.assert_called_with(role="kube_role", jwt="data")
test_client.is_authenticated.assert_called_with()
- self.assertEqual(2, test_hook.vault_client.kv_engine_version)
+ assert 2 == test_hook.vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -616,7 +617,7 @@ def test_ldap_init_params(self, mock_hvac, mock_get_connection):
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
test_client.auth.ldap.login.assert_called_with(username="user", password="pass")
test_client.is_authenticated.assert_called_with()
- self.assertEqual(2, test_hook.vault_client.kv_engine_version)
+ assert 2 == test_hook.vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -641,7 +642,7 @@ def test_ldap_dejson(self, mock_hvac, mock_get_connection):
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
test_client.auth.ldap.login.assert_called_with(username="user", password="pass")
test_client.is_authenticated.assert_called_with()
- self.assertEqual(2, test_hook.vault_client.kv_engine_version)
+ assert 2 == test_hook.vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -666,7 +667,7 @@ def test_radius_init_params(self, mock_hvac, mock_get_connection):
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
test_client.auth.radius.configure.assert_called_with(host="radhost", secret="pass", port=None)
test_client.is_authenticated.assert_called_with()
- self.assertEqual(2, test_hook.vault_client.kv_engine_version)
+ assert 2 == test_hook.vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -692,7 +693,7 @@ def test_radius_init_params_port(self, mock_hvac, mock_get_connection):
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
test_client.auth.radius.configure.assert_called_with(host="radhost", secret="pass", port=8123)
test_client.is_authenticated.assert_called_with()
- self.assertEqual(2, test_hook.vault_client.kv_engine_version)
+ assert 2 == test_hook.vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -719,7 +720,7 @@ def test_radius_dejson(self, mock_hvac, mock_get_connection):
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
test_client.auth.radius.configure.assert_called_with(host="radhost", secret="pass", port=8123)
test_client.is_authenticated.assert_called_with()
- self.assertEqual(2, test_hook.vault_client.kv_engine_version)
+ assert 2 == test_hook.vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -740,7 +741,7 @@ def test_radius_dejson_wrong_port(self, mock_hvac, mock_get_connection):
"vault_conn_id": "vault_conn_id",
}
- with self.assertRaisesRegex(VaultError, "Radius port was wrong: wrong"):
+ with pytest.raises(VaultError, match="Radius port was wrong: wrong"):
VaultHook(**kwargs)
@mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection")
@@ -759,9 +760,9 @@ def test_token_init_params(self, mock_hvac, mock_get_connection):
test_client = test_hook.get_conn()
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
test_client.is_authenticated.assert_called_with()
- self.assertEqual("pass", test_client.token)
- self.assertEqual(2, test_hook.vault_client.kv_engine_version)
- self.assertEqual("secret", test_hook.vault_client.mount_point)
+ assert "pass" == test_client.token
+ assert 2 == test_hook.vault_client.kv_engine_version
+ assert "secret" == test_hook.vault_client.mount_point
@mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -785,8 +786,8 @@ def test_token_dejson(self, mock_hvac, mock_get_connection):
test_client = test_hook.get_conn()
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
test_client.is_authenticated.assert_called_with()
- self.assertEqual("pass", test_client.token)
- self.assertEqual(2, test_hook.vault_client.kv_engine_version)
+ assert "pass" == test_client.token
+ assert 2 == test_hook.vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -807,7 +808,7 @@ def test_userpass_init_params(self, mock_hvac, mock_get_connection):
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
test_client.auth_userpass.assert_called_with(username="user", password="pass")
test_client.is_authenticated.assert_called_with()
- self.assertEqual(2, test_hook.vault_client.kv_engine_version)
+ assert 2 == test_hook.vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -832,7 +833,7 @@ def test_userpass_dejson(self, mock_hvac, mock_get_connection):
mock_hvac.Client.assert_called_with(url='http://localhost:8180')
test_client.auth_userpass.assert_called_with(username="user", password="pass")
test_client.is_authenticated.assert_called_with()
- self.assertEqual(2, test_hook.vault_client.kv_engine_version)
+ assert 2 == test_hook.vault_client.kv_engine_version
@mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -868,7 +869,7 @@ def test_get_existing_key_v2(self, mock_hvac, mock_get_connection):
test_hook = VaultHook(**kwargs)
secret = test_hook.get_secret(secret_path="missing")
- self.assertEqual({'secret_key': 'secret_value'}, secret)
+ assert {'secret_key': 'secret_value'} == secret
mock_client.secrets.kv.v2.read_secret_version.assert_called_once_with(
mount_point='secret', path='missing', version=None
)
@@ -907,7 +908,7 @@ def test_get_existing_key_v2_version(self, mock_hvac, mock_get_connection):
test_hook = VaultHook(**kwargs)
secret = test_hook.get_secret(secret_path="missing", secret_version=1)
- self.assertEqual({'secret_key': 'secret_value'}, secret)
+ assert {'secret_key': 'secret_value'} == secret
mock_client.secrets.kv.v2.read_secret_version.assert_called_once_with(
mount_point='secret', path='missing', version=1
)
@@ -938,7 +939,7 @@ def test_get_existing_key_v1(self, mock_hvac, mock_get_connection):
test_hook = VaultHook(**kwargs)
secret = test_hook.get_secret(secret_path="missing")
- self.assertEqual({'value': 'world'}, secret)
+ assert {'value': 'world'} == secret
mock_client.secrets.kv.v1.read_secret.assert_called_once_with(mount_point='secret', path='missing')
@mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection")
@@ -977,29 +978,26 @@ def test_get_secret_metadata_v2(self, mock_hvac, mock_get_connection):
test_hook = VaultHook(**kwargs)
metadata = test_hook.get_secret_metadata(secret_path="missing")
- self.assertEqual(
- {
- 'request_id': '94011e25-f8dc-ec29-221b-1f9c1d9ad2ae',
- 'lease_id': '',
- 'renewable': False,
- 'lease_duration': 0,
- 'metadata': [
- {
- 'created_time': '2020-03-16T21:01:43.331126Z',
- 'deletion_time': '',
- 'destroyed': False,
- 'version': 1,
- },
- {
- 'created_time': '2020-03-16T21:01:43.331126Z',
- 'deletion_time': '',
- 'destroyed': False,
- 'version': 2,
- },
- ],
- },
- metadata,
- )
+ assert {
+ 'request_id': '94011e25-f8dc-ec29-221b-1f9c1d9ad2ae',
+ 'lease_id': '',
+ 'renewable': False,
+ 'lease_duration': 0,
+ 'metadata': [
+ {
+ 'created_time': '2020-03-16T21:01:43.331126Z',
+ 'deletion_time': '',
+ 'destroyed': False,
+ 'version': 1,
+ },
+ {
+ 'created_time': '2020-03-16T21:01:43.331126Z',
+ 'deletion_time': '',
+ 'destroyed': False,
+ 'version': 2,
+ },
+ ],
+ } == metadata
mock_client.secrets.kv.v2.read_secret_metadata.assert_called_once_with(
mount_point='secret', path='missing'
)
@@ -1038,27 +1036,24 @@ def test_get_secret_including_metadata_v2(self, mock_hvac, mock_get_connection):
test_hook = VaultHook(**kwargs)
metadata = test_hook.get_secret_including_metadata(secret_path="missing")
- self.assertEqual(
- {
- 'request_id': '94011e25-f8dc-ec29-221b-1f9c1d9ad2ae',
- 'lease_id': '',
- 'renewable': False,
- 'lease_duration': 0,
- 'data': {
- 'data': {'secret_key': 'secret_value'},
- 'metadata': {
- 'created_time': '2020-03-16T21:01:43.331126Z',
- 'deletion_time': '',
- 'destroyed': False,
- 'version': 1,
- },
+ assert {
+ 'request_id': '94011e25-f8dc-ec29-221b-1f9c1d9ad2ae',
+ 'lease_id': '',
+ 'renewable': False,
+ 'lease_duration': 0,
+ 'data': {
+ 'data': {'secret_key': 'secret_value'},
+ 'metadata': {
+ 'created_time': '2020-03-16T21:01:43.331126Z',
+ 'deletion_time': '',
+ 'destroyed': False,
+ 'version': 1,
},
- 'wrap_info': None,
- 'warnings': None,
- 'auth': None,
},
- metadata,
- )
+ 'wrap_info': None,
+ 'warnings': None,
+ 'auth': None,
+ } == metadata
mock_client.secrets.kv.v2.read_secret_version.assert_called_once_with(
mount_point='secret', path='missing', version=None
)
diff --git a/tests/providers/hashicorp/secrets/test_vault.py b/tests/providers/hashicorp/secrets/test_vault.py
index f71cfab1efdc6..da293d3204fba 100644
--- a/tests/providers/hashicorp/secrets/test_vault.py
+++ b/tests/providers/hashicorp/secrets/test_vault.py
@@ -17,6 +17,7 @@
from unittest import TestCase, mock
+import pytest
from hvac.exceptions import InvalidPath, VaultError
from airflow.providers.hashicorp.secrets.vault import VaultBackend
@@ -56,7 +57,7 @@ def test_get_conn_uri(self, mock_hvac):
test_client = VaultBackend(**kwargs)
returned_uri = test_client.get_conn_uri(conn_id="test_postgres")
- self.assertEqual('postgresql://airflow:airflow@host:5432/airflow', returned_uri)
+ assert 'postgresql://airflow:airflow@host:5432/airflow' == returned_uri
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_get_conn_uri_engine_version_1(self, mock_hvac):
@@ -87,7 +88,7 @@ def test_get_conn_uri_engine_version_1(self, mock_hvac):
mock_client.secrets.kv.v1.read_secret.assert_called_once_with(
mount_point='airflow', path='connections/test_postgres'
)
- self.assertEqual('postgresql://airflow:airflow@host:5432/airflow', returned_uri)
+ assert 'postgresql://airflow:airflow@host:5432/airflow' == returned_uri
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_get_conn_uri_engine_version_1_custom_auth_mount_point(self, mock_hvac):
@@ -115,12 +116,12 @@ def test_get_conn_uri_engine_version_1_custom_auth_mount_point(self, mock_hvac):
}
test_client = VaultBackend(**kwargs)
- self.assertEqual("custom", test_client.vault_client.auth_mount_point)
+ assert "custom" == test_client.vault_client.auth_mount_point
returned_uri = test_client.get_conn_uri(conn_id="test_postgres")
mock_client.secrets.kv.v1.read_secret.assert_called_once_with(
mount_point='airflow', path='connections/test_postgres'
)
- self.assertEqual('postgresql://airflow:airflow@host:5432/airflow', returned_uri)
+ assert 'postgresql://airflow:airflow@host:5432/airflow' == returned_uri
@mock.patch.dict(
'os.environ',
@@ -148,11 +149,11 @@ def test_get_conn_uri_non_existent_key(self, mock_hvac):
}
test_client = VaultBackend(**kwargs)
- self.assertIsNone(test_client.get_conn_uri(conn_id="test_mysql"))
+ assert test_client.get_conn_uri(conn_id="test_mysql") is None
mock_client.secrets.kv.v2.read_secret_version.assert_called_once_with(
mount_point='airflow', path='connections/test_mysql', version=None
)
- self.assertEqual([], test_client.get_connections(conn_id="test_mysql"))
+ assert [] == test_client.get_connections(conn_id="test_mysql")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_get_variable_value(self, mock_hvac):
@@ -187,7 +188,7 @@ def test_get_variable_value(self, mock_hvac):
test_client = VaultBackend(**kwargs)
returned_uri = test_client.get_variable("hello")
- self.assertEqual('world', returned_uri)
+ assert 'world' == returned_uri
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_get_variable_value_engine_version_1(self, mock_hvac):
@@ -218,7 +219,7 @@ def test_get_variable_value_engine_version_1(self, mock_hvac):
mock_client.secrets.kv.v1.read_secret.assert_called_once_with(
mount_point='airflow', path='variables/hello'
)
- self.assertEqual('world', returned_uri)
+ assert 'world' == returned_uri
@mock.patch.dict(
'os.environ',
@@ -246,11 +247,11 @@ def test_get_variable_value_non_existent_key(self, mock_hvac):
}
test_client = VaultBackend(**kwargs)
- self.assertIsNone(test_client.get_variable("hello"))
+ assert test_client.get_variable("hello") is None
mock_client.secrets.kv.v2.read_secret_version.assert_called_once_with(
mount_point='airflow', path='variables/hello', version=None
)
- self.assertIsNone(test_client.get_variable("hello"))
+ assert test_client.get_variable("hello") is None
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_auth_failure_raises_error(self, mock_hvac):
@@ -266,7 +267,7 @@ def test_auth_failure_raises_error(self, mock_hvac):
"token": "test_wrong_token",
}
- with self.assertRaisesRegex(VaultError, "Vault Authentication Error!"):
+ with pytest.raises(VaultError, match="Vault Authentication Error!"):
VaultBackend(**kwargs).get_connections(conn_id='test')
def test_auth_type_kubernetes_with_unreadable_jwt_raises_error(self):
@@ -278,7 +279,7 @@ def test_auth_type_kubernetes_with_unreadable_jwt_raises_error(self):
"url": "http://127.0.0.1:8200",
}
- with self.assertRaisesRegex(FileNotFoundError, path):
+ with pytest.raises(FileNotFoundError, match=path):
VaultBackend(**kwargs).get_connections(conn_id='test')
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -314,7 +315,7 @@ def test_get_config_value(self, mock_hvac):
test_client = VaultBackend(**kwargs)
returned_uri = test_client.get_config("sql_alchemy_conn")
- self.assertEqual('sqlite:////Users/airflow/airflow/airflow.db', returned_uri)
+ assert 'sqlite:////Users/airflow/airflow/airflow.db' == returned_uri
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
def test_connections_path_none_value(self, mock_hvac):
@@ -330,7 +331,7 @@ def test_connections_path_none_value(self, mock_hvac):
}
test_client = VaultBackend(**kwargs)
- self.assertIsNone(test_client.get_conn_uri(conn_id="test"))
+ assert test_client.get_conn_uri(conn_id="test") is None
mock_hvac.Client.assert_not_called()
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -347,7 +348,7 @@ def test_variables_path_none_value(self, mock_hvac):
}
test_client = VaultBackend(**kwargs)
- self.assertIsNone(test_client.get_variable("hello"))
+ assert test_client.get_variable("hello") is None
mock_hvac.Client.assert_not_called()
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@@ -364,5 +365,5 @@ def test_config_path_none_value(self, mock_hvac):
}
test_client = VaultBackend(**kwargs)
- self.assertIsNone(test_client.get_config("test"))
+ assert test_client.get_config("test") is None
mock_hvac.Client.assert_not_called()
diff --git a/tests/providers/http/hooks/test_http.py b/tests/providers/http/hooks/test_http.py
index 777aa7c3afa13..816adc3686fa8 100644
--- a/tests/providers/http/hooks/test_http.py
+++ b/tests/providers/http/hooks/test_http.py
@@ -19,6 +19,7 @@
import unittest
from unittest import mock
+import pytest
import requests
import requests_mock
import tenacity
@@ -54,7 +55,7 @@ def test_raise_for_status_with_200(self, m):
m.get('http://test:8080/v1/test', status_code=200, text='{"status":{"status": 200}}', reason='OK')
with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection):
resp = self.get_hook.run('v1/test')
- self.assertEqual(resp.text, '{"status":{"status": 200}}')
+ assert resp.text == '{"status":{"status": 200}}'
@requests_mock.mock()
@mock.patch('requests.Session')
@@ -91,15 +92,15 @@ def test_get_request_do_not_raise_for_status_if_check_response_is_false(self, m)
with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection):
resp = self.get_hook.run('v1/test', extra_options={'check_response': False})
- self.assertEqual(resp.text, '{"status":{"status": 404}}')
+ assert resp.text == '{"status":{"status": 404}}'
@requests_mock.mock()
def test_hook_contains_header_from_extra_field(self, mock_requests):
with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection):
expected_conn = get_airflow_connection()
conn = self.get_hook.get_conn()
- self.assertDictContainsSubset(json.loads(expected_conn.extra), conn.headers)
- self.assertEqual(conn.headers.get('bareer'), 'test')
+ assert dict(conn.headers, **json.loads(expected_conn.extra)) == conn.headers
+ assert conn.headers.get('bareer') == 'test'
@requests_mock.mock()
@mock.patch('requests.Request')
@@ -119,18 +120,18 @@ def test_hook_with_method_in_lowercase(self, mock_requests, request_mock):
@requests_mock.mock()
def test_hook_uses_provided_header(self, mock_requests):
conn = self.get_hook.get_conn(headers={"bareer": "newT0k3n"})
- self.assertEqual(conn.headers.get('bareer'), "newT0k3n")
+ assert conn.headers.get('bareer') == "newT0k3n"
@requests_mock.mock()
def test_hook_has_no_header_from_extra(self, mock_requests):
conn = self.get_hook.get_conn()
- self.assertIsNone(conn.headers.get('bareer'))
+ assert conn.headers.get('bareer') is None
@requests_mock.mock()
def test_hooks_header_from_extra_is_overridden(self, mock_requests):
with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection):
conn = self.get_hook.get_conn(headers={"bareer": "newT0k3n"})
- self.assertEqual(conn.headers.get('bareer'), 'newT0k3n')
+ assert conn.headers.get('bareer') == 'newT0k3n'
@requests_mock.mock()
def test_post_request(self, mock_requests):
@@ -140,7 +141,7 @@ def test_post_request(self, mock_requests):
with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection):
resp = self.post_hook.run('v1/test')
- self.assertEqual(resp.status_code, 200)
+ assert resp.status_code == 200
@requests_mock.mock()
def test_post_request_with_error_code(self, mock_requests):
@@ -152,7 +153,7 @@ def test_post_request_with_error_code(self, mock_requests):
)
with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.post_hook.run('v1/test')
@requests_mock.mock()
@@ -166,7 +167,7 @@ def test_post_request_do_not_raise_for_status_if_check_response_is_false(self, m
with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection):
resp = self.post_hook.run('v1/test', extra_options={'check_response': False})
- self.assertEqual(resp.status_code, 418)
+ assert resp.status_code == 418
@mock.patch('airflow.providers.http.hooks.http.requests.Session')
def test_retry_on_conn_error(self, mocked_session):
@@ -182,9 +183,9 @@ def send_and_raise(unused_request, **kwargs):
mocked_session().send.side_effect = send_and_raise
# The job failed for some reason
- with self.assertRaises(tenacity.RetryError):
+ with pytest.raises(tenacity.RetryError):
self.get_hook.run_with_advanced_retry(endpoint='v1/test', _retry_args=retry_args)
- self.assertEqual(self.get_hook._retry_obj.stop.max_attempt_number + 1, mocked_session.call_count)
+ assert self.get_hook._retry_obj.stop.max_attempt_number + 1 == mocked_session.call_count
@requests_mock.mock()
def test_run_with_advanced_retry(self, m):
@@ -199,7 +200,7 @@ def test_run_with_advanced_retry(self, m):
)
with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection):
response = self.get_hook.run_with_advanced_retry(endpoint='v1/test', _retry_args=retry_args)
- self.assertIsInstance(response, requests.Response)
+ assert isinstance(response, requests.Response)
def test_header_from_extra_and_run_method_are_merged(self):
def run_and_return(unused_session, prepped_request, unused_extra_options, **kwargs):
@@ -212,8 +213,8 @@ def run_and_return(unused_session, prepped_request, unused_extra_options, **kwar
with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection):
prepared_request = self.get_hook.run('v1/test', headers={'some_other_header': 'test'})
actual = dict(prepared_request.headers)
- self.assertEqual(actual.get('bareer'), 'test')
- self.assertEqual(actual.get('some_other_header'), 'test')
+ assert actual.get('bareer') == 'test'
+ assert actual.get('some_other_header') == 'test'
@mock.patch('airflow.providers.http.hooks.http.HttpHook.get_connection')
def test_http_connection(self, mock_get_connection):
@@ -221,7 +222,7 @@ def test_http_connection(self, mock_get_connection):
mock_get_connection.return_value = conn
hook = HttpHook()
hook.get_conn({})
- self.assertEqual(hook.base_url, 'http://localhost')
+ assert hook.base_url == 'http://localhost'
@mock.patch('airflow.providers.http.hooks.http.HttpHook.get_connection')
def test_https_connection(self, mock_get_connection):
@@ -229,7 +230,7 @@ def test_https_connection(self, mock_get_connection):
mock_get_connection.return_value = conn
hook = HttpHook()
hook.get_conn({})
- self.assertEqual(hook.base_url, 'https://localhost')
+ assert hook.base_url == 'https://localhost'
@mock.patch('airflow.providers.http.hooks.http.HttpHook.get_connection')
def test_host_encoded_http_connection(self, mock_get_connection):
@@ -237,7 +238,7 @@ def test_host_encoded_http_connection(self, mock_get_connection):
mock_get_connection.return_value = conn
hook = HttpHook()
hook.get_conn({})
- self.assertEqual(hook.base_url, 'http://localhost')
+ assert hook.base_url == 'http://localhost'
@mock.patch('airflow.providers.http.hooks.http.HttpHook.get_connection')
def test_host_encoded_https_connection(self, mock_get_connection):
@@ -245,10 +246,10 @@ def test_host_encoded_https_connection(self, mock_get_connection):
mock_get_connection.return_value = conn
hook = HttpHook()
hook.get_conn({})
- self.assertEqual(hook.base_url, 'https://localhost')
+ assert hook.base_url == 'https://localhost'
def test_method_converted_to_uppercase_when_created_in_lowercase(self):
- self.assertEqual(self.get_lowercase_hook.method, 'GET')
+ assert self.get_lowercase_hook.method == 'GET'
@mock.patch('airflow.providers.http.hooks.http.HttpHook.get_connection')
def test_connection_without_host(self, mock_get_connection):
@@ -257,7 +258,7 @@ def test_connection_without_host(self, mock_get_connection):
hook = HttpHook()
hook.get_conn({})
- self.assertEqual(hook.base_url, 'http://')
+ assert hook.base_url == 'http://'
@parameterized.expand(
[
diff --git a/tests/providers/http/operators/test_http.py b/tests/providers/http/operators/test_http.py
index 9704311d4f3d7..d4c622c29e6a5 100644
--- a/tests/providers/http/operators/test_http.py
+++ b/tests/providers/http/operators/test_http.py
@@ -19,6 +19,7 @@
import unittest
from unittest import mock
+import pytest
import requests_mock
from airflow.exceptions import AirflowException
@@ -69,7 +70,8 @@ def response_check(response):
)
with mock.patch.object(operator.log, 'info') as mock_info:
- self.assertRaises(AirflowException, operator.execute, {})
+ with pytest.raises(AirflowException):
+ operator.execute({})
calls = [mock.call('Calling HTTP method'), mock.call('invalid response')]
mock_info.assert_has_calls(calls, any_order=True)
diff --git a/tests/providers/http/sensors/test_http.py b/tests/providers/http/sensors/test_http.py
index 04b6c4946c758..b14309fbd603e 100644
--- a/tests/providers/http/sensors/test_http.py
+++ b/tests/providers/http/sensors/test_http.py
@@ -19,6 +19,7 @@
from unittest import mock
from unittest.mock import patch
+import pytest
import requests
from airflow.exceptions import AirflowException, AirflowSensorTimeout
@@ -59,7 +60,7 @@ def resp_check(_):
timeout=5,
poke_interval=1,
)
- with self.assertRaisesRegex(AirflowException, 'AirflowException raised here!'):
+ with pytest.raises(AirflowException, match='AirflowException raised here!'):
task.execute(context={})
@patch("airflow.providers.http.hooks.http.requests.Session.send")
@@ -86,8 +87,8 @@ def resp_check(_):
prep_request = requests.Request('HEAD', 'https://www.httpbin.org', {}).prepare()
- self.assertEqual(prep_request.url, received_request.url)
- self.assertTrue(prep_request.method, received_request.method)
+ assert prep_request.url == received_request.url
+ assert prep_request.method, received_request.method
@patch("airflow.providers.http.hooks.http.requests.Session.send")
def test_poke_context(self, mock_session_send):
@@ -138,10 +139,10 @@ def resp_check(_):
)
with mock.patch.object(task.hook.log, 'error') as mock_errors:
- with self.assertRaises(AirflowSensorTimeout):
+ with pytest.raises(AirflowSensorTimeout):
task.execute(None)
- self.assertTrue(mock_errors.called)
+ assert mock_errors.called
calls = [
mock.call('HTTP error: %s', 'Not Found'),
mock.call('This endpoint doesnt exist'),
diff --git a/tests/providers/imap/hooks/test_imap.py b/tests/providers/imap/hooks/test_imap.py
index b37a63ecbd7e1..671d525461501 100644
--- a/tests/providers/imap/hooks/test_imap.py
+++ b/tests/providers/imap/hooks/test_imap.py
@@ -20,6 +20,8 @@
import unittest
from unittest.mock import Mock, mock_open, patch
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.imap.hooks.imap import ImapHook
@@ -81,7 +83,7 @@ def test_has_mail_attachment_found(self, mock_imaplib):
with ImapHook() as imap_hook:
has_attachment_in_inbox = imap_hook.has_mail_attachment('test1.csv')
- self.assertTrue(has_attachment_in_inbox)
+ assert has_attachment_in_inbox
@patch(imaplib_string)
def test_has_mail_attachment_not_found(self, mock_imaplib):
@@ -90,7 +92,7 @@ def test_has_mail_attachment_not_found(self, mock_imaplib):
with ImapHook() as imap_hook:
has_attachment_in_inbox = imap_hook.has_mail_attachment('test1.txt')
- self.assertFalse(has_attachment_in_inbox)
+ assert not has_attachment_in_inbox
@patch(imaplib_string)
def test_has_mail_attachment_with_regex_found(self, mock_imaplib):
@@ -99,7 +101,7 @@ def test_has_mail_attachment_with_regex_found(self, mock_imaplib):
with ImapHook() as imap_hook:
has_attachment_in_inbox = imap_hook.has_mail_attachment(name=r'test(\d+).csv', check_regex=True)
- self.assertTrue(has_attachment_in_inbox)
+ assert has_attachment_in_inbox
@patch(imaplib_string)
def test_has_mail_attachment_with_regex_not_found(self, mock_imaplib):
@@ -108,7 +110,7 @@ def test_has_mail_attachment_with_regex_not_found(self, mock_imaplib):
with ImapHook() as imap_hook:
has_attachment_in_inbox = imap_hook.has_mail_attachment(name=r'test_(\d+).csv', check_regex=True)
- self.assertFalse(has_attachment_in_inbox)
+ assert not has_attachment_in_inbox
@patch(imaplib_string)
def test_has_mail_attachment_with_mail_filter(self, mock_imaplib):
@@ -127,14 +129,15 @@ def test_retrieve_mail_attachments_found(self, mock_imaplib):
with ImapHook() as imap_hook:
attachments_in_inbox = imap_hook.retrieve_mail_attachments('test1.csv')
- self.assertEqual(attachments_in_inbox, [('test1.csv', b'SWQsTmFtZQoxLEZlbGl4')])
+ assert attachments_in_inbox == [('test1.csv', b'SWQsTmFtZQoxLEZlbGl4')]
@patch(imaplib_string)
def test_retrieve_mail_attachments_not_found(self, mock_imaplib):
_create_fake_imap(mock_imaplib, with_mail=True)
with ImapHook() as imap_hook:
- self.assertRaises(AirflowException, imap_hook.retrieve_mail_attachments, 'test1.txt')
+ with pytest.raises(AirflowException):
+ imap_hook.retrieve_mail_attachments('test1.txt')
@patch(imaplib_string)
def test_retrieve_mail_attachments_with_regex_found(self, mock_imaplib):
@@ -145,19 +148,18 @@ def test_retrieve_mail_attachments_with_regex_found(self, mock_imaplib):
name=r'test(\d+).csv', check_regex=True
)
- self.assertEqual(attachments_in_inbox, [('test1.csv', b'SWQsTmFtZQoxLEZlbGl4')])
+ assert attachments_in_inbox == [('test1.csv', b'SWQsTmFtZQoxLEZlbGl4')]
@patch(imaplib_string)
def test_retrieve_mail_attachments_with_regex_not_found(self, mock_imaplib):
_create_fake_imap(mock_imaplib, with_mail=True)
with ImapHook() as imap_hook:
- self.assertRaises(
- AirflowException,
- imap_hook.retrieve_mail_attachments,
- name=r'test_(\d+).csv',
- check_regex=True,
- )
+ with pytest.raises(AirflowException):
+ imap_hook.retrieve_mail_attachments(
+ name=r'test_(\d+).csv',
+ check_regex=True,
+ )
@patch(imaplib_string)
def test_retrieve_mail_attachments_latest_only(self, mock_imaplib):
@@ -166,7 +168,7 @@ def test_retrieve_mail_attachments_latest_only(self, mock_imaplib):
with ImapHook() as imap_hook:
attachments_in_inbox = imap_hook.retrieve_mail_attachments(name='test1.csv', latest_only=True)
- self.assertEqual(attachments_in_inbox, [('test1.csv', b'SWQsTmFtZQoxLEZlbGl4')])
+ assert attachments_in_inbox == [('test1.csv', b'SWQsTmFtZQoxLEZlbGl4')]
@patch(imaplib_string)
def test_retrieve_mail_attachments_with_mail_filter(self, mock_imaplib):
@@ -195,9 +197,8 @@ def test_download_mail_attachments_not_found(self, mock_imaplib, mock_open_metho
_create_fake_imap(mock_imaplib, with_mail=True)
with ImapHook() as imap_hook:
- self.assertRaises(
- AirflowException, imap_hook.download_mail_attachments, 'test1.txt', 'test_directory'
- )
+ with pytest.raises(AirflowException):
+ imap_hook.download_mail_attachments('test1.txt', 'test_directory')
mock_open_method.assert_not_called()
mock_open_method.return_value.write.assert_not_called()
@@ -221,13 +222,12 @@ def test_download_mail_attachments_with_regex_not_found(self, mock_imaplib, mock
_create_fake_imap(mock_imaplib, with_mail=True)
with ImapHook() as imap_hook:
- self.assertRaises(
- AirflowException,
- imap_hook.download_mail_attachments,
- name=r'test_(\d+).csv',
- local_output_directory='test_directory',
- check_regex=True,
- )
+ with pytest.raises(AirflowException):
+ imap_hook.download_mail_attachments(
+ name=r'test_(\d+).csv',
+ local_output_directory='test_directory',
+ check_regex=True,
+ )
mock_open_method.assert_not_called()
mock_open_method.return_value.write.assert_not_called()
diff --git a/tests/providers/imap/sensors/test_imap_attachment.py b/tests/providers/imap/sensors/test_imap_attachment.py
index 5af9679d3ea06..6fe22b67165ca 100644
--- a/tests/providers/imap/sensors/test_imap_attachment.py
+++ b/tests/providers/imap/sensors/test_imap_attachment.py
@@ -43,7 +43,7 @@ def test_poke(self, has_attachment_return_value, mock_imap_hook):
has_attachment = ImapAttachmentSensor(**self.kwargs).poke(context={})
- self.assertEqual(has_attachment, mock_imap_hook.has_mail_attachment.return_value)
+ assert has_attachment == mock_imap_hook.has_mail_attachment.return_value
mock_imap_hook.has_mail_attachment.assert_called_once_with(
name=self.kwargs['attachment_name'],
check_regex=self.kwargs['check_regex'],
diff --git a/tests/providers/jdbc/hooks/test_jdbc.py b/tests/providers/jdbc/hooks/test_jdbc.py
index e0585c53b4795..091fe56060978 100644
--- a/tests/providers/jdbc/hooks/test_jdbc.py
+++ b/tests/providers/jdbc/hooks/test_jdbc.py
@@ -49,9 +49,9 @@ def setUp(self):
def test_jdbc_conn_connection(self, jdbc_mock):
jdbc_hook = JdbcHook()
jdbc_conn = jdbc_hook.get_conn()
- self.assertTrue(jdbc_mock.called)
- self.assertIsInstance(jdbc_conn, Mock)
- self.assertEqual(jdbc_conn.name, jdbc_mock.return_value.name) # pylint: disable=no-member
+ assert jdbc_mock.called
+ assert isinstance(jdbc_conn, Mock)
+ assert jdbc_conn.name == jdbc_mock.return_value.name # pylint: disable=no-member
@patch("airflow.providers.jdbc.hooks.jdbc.jaydebeapi.connect")
def test_jdbc_conn_set_autocommit(self, _):
diff --git a/tests/providers/jenkins/hooks/test_jenkins.py b/tests/providers/jenkins/hooks/test_jenkins.py
index ebcdeeda49cb3..71fb45c7cb89f 100644
--- a/tests/providers/jenkins/hooks/test_jenkins.py
+++ b/tests/providers/jenkins/hooks/test_jenkins.py
@@ -41,8 +41,8 @@ def test_client_created_default_http(self, get_connection_mock):
complete_url = f'http://{connection_host}:{connection_port}/'
hook = JenkinsHook(default_connection_id)
- self.assertIsNotNone(hook.jenkins_server)
- self.assertEqual(hook.jenkins_server.server, complete_url)
+ assert hook.jenkins_server is not None
+ assert hook.jenkins_server.server == complete_url
@mock.patch('airflow.hooks.base.BaseHook.get_connection')
def test_client_created_default_https(self, get_connection_mock):
@@ -63,5 +63,5 @@ def test_client_created_default_https(self, get_connection_mock):
complete_url = f'https://{connection_host}:{connection_port}/'
hook = JenkinsHook(default_connection_id)
- self.assertIsNotNone(hook.jenkins_server)
- self.assertEqual(hook.jenkins_server.server, complete_url)
+ assert hook.jenkins_server is not None
+ assert hook.jenkins_server.server == complete_url
diff --git a/tests/providers/jenkins/operators/test_jenkins_job_trigger.py b/tests/providers/jenkins/operators/test_jenkins_job_trigger.py
index a3dad65ef5431..038d780ab55d8 100644
--- a/tests/providers/jenkins/operators/test_jenkins_job_trigger.py
+++ b/tests/providers/jenkins/operators/test_jenkins_job_trigger.py
@@ -20,6 +20,7 @@
from unittest.mock import Mock, patch
import jenkins
+import pytest
from parameterized import parameterized
from airflow.exceptions import AirflowException
@@ -75,7 +76,7 @@ def test_execute(self, _, parameters):
operator.execute(None)
- self.assertEqual(jenkins_mock.get_build_info.call_count, 1)
+ assert jenkins_mock.get_build_info.call_count == 1
jenkins_mock.get_build_info.assert_called_once_with(name='a_job_on_jenkins', number='1')
@parameterized.expand(
@@ -125,7 +126,7 @@ def test_execute_job_polling_loop(self, _, parameters):
)
operator.execute(None)
- self.assertEqual(jenkins_mock.get_build_info.call_count, 2)
+ assert jenkins_mock.get_build_info.call_count == 2
@parameterized.expand(
[
@@ -173,7 +174,8 @@ def test_execute_job_failure(self, _, parameters):
sleep_time=1,
)
- self.assertRaises(AirflowException, operator.execute, None)
+ with pytest.raises(AirflowException):
+ operator.execute(None)
def test_build_job_request_settings(self):
jenkins_mock = Mock(spec=jenkins.Jenkins, auth='secret', timeout=2)
@@ -191,5 +193,5 @@ def test_build_job_request_settings(self):
operator.build_job(jenkins_mock)
mock_request = mock_make_request.call_args_list[0][0][1]
- self.assertEqual(mock_request.method, 'POST')
- self.assertEqual(mock_request.url, 'http://apache.org')
+ assert mock_request.method == 'POST'
+ assert mock_request.url == 'http://apache.org'
diff --git a/tests/providers/jira/hooks/test_jira.py b/tests/providers/jira/hooks/test_jira.py
index 06c50c551e106..d511ebb6820e9 100644
--- a/tests/providers/jira/hooks/test_jira.py
+++ b/tests/providers/jira/hooks/test_jira.py
@@ -43,6 +43,6 @@ def setUp(self):
def test_jira_client_connection(self, jira_mock):
jira_hook = JiraHook()
- self.assertTrue(jira_mock.called)
- self.assertIsInstance(jira_hook.client, Mock)
- self.assertEqual(jira_hook.client.name, jira_mock.return_value.name) # pylint: disable=no-member
+ assert jira_mock.called
+ assert isinstance(jira_hook.client, Mock)
+ assert jira_hook.client.name == jira_mock.return_value.name # pylint: disable=no-member
diff --git a/tests/providers/jira/operators/test_jira.py b/tests/providers/jira/operators/test_jira.py
index a1b011e19a4d4..c668d51a88ac1 100644
--- a/tests/providers/jira/operators/test_jira.py
+++ b/tests/providers/jira/operators/test_jira.py
@@ -68,8 +68,8 @@ def test_issue_search(self, jira_mock):
jira_ticket_search_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
- self.assertTrue(jira_mock.called)
- self.assertTrue(jira_mock.return_value.search_issues.called)
+ assert jira_mock.called
+ assert jira_mock.return_value.search_issues.called
@patch("airflow.providers.jira.hooks.jira.JIRA", autospec=True, return_value=jira_client_mock)
def test_update_issue(self, jira_mock):
@@ -84,5 +84,5 @@ def test_update_issue(self, jira_mock):
add_comment_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
- self.assertTrue(jira_mock.called)
- self.assertTrue(jira_mock.return_value.add_comment.called)
+ assert jira_mock.called
+ assert jira_mock.return_value.add_comment.called
diff --git a/tests/providers/jira/sensors/test_jira.py b/tests/providers/jira/sensors/test_jira.py
index 48758dd84dd47..4d4a54b8ef45d 100644
--- a/tests/providers/jira/sensors/test_jira.py
+++ b/tests/providers/jira/sensors/test_jira.py
@@ -70,8 +70,8 @@ def test_issue_label_set(self, jira_mock):
ticket_label_sensor.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
- self.assertTrue(jira_mock.called)
- self.assertTrue(jira_mock.return_value.issue.called)
+ assert jira_mock.called
+ assert jira_mock.return_value.issue.called
@staticmethod
def field_checker_func(context, issue): # pylint: disable=unused-argument
diff --git a/tests/providers/microsoft/azure/hooks/test_adx.py b/tests/providers/microsoft/azure/hooks/test_adx.py
index e3a3ab3dbd29f..9da4ce1f6c827 100644
--- a/tests/providers/microsoft/azure/hooks/test_adx.py
+++ b/tests/providers/microsoft/azure/hooks/test_adx.py
@@ -21,6 +21,7 @@
import unittest
from unittest import mock
+import pytest
from azure.kusto.data.request import ClientRequestProperties, KustoClient, KustoConnectionStringBuilder
from airflow.exceptions import AirflowException
@@ -49,9 +50,9 @@ def test_conn_missing_method(self):
extra=json.dumps({}),
)
)
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID)
- self.assertIn('missing required parameter: `auth_method`', str(e.exception))
+ assert 'missing required parameter: `auth_method`' in str(ctx.value)
def test_conn_unknown_method(self):
db.merge_conn(
@@ -64,9 +65,9 @@ def test_conn_unknown_method(self):
extra=json.dumps({'auth_method': 'AAD_OTHER'}),
)
)
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID)
- self.assertIn('Unknown authentication method: AAD_OTHER', str(e.exception))
+ assert 'Unknown authentication method: AAD_OTHER' in str(ctx.value)
def test_conn_missing_cluster(self):
db.merge_conn(
@@ -78,9 +79,9 @@ def test_conn_missing_cluster(self):
extra=json.dumps({}),
)
)
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID)
- self.assertIn('Host connection option is required', str(e.exception))
+ assert 'Host connection option is required' in str(ctx.value)
@mock.patch.object(KustoClient, '__init__')
def test_conn_method_aad_creds(self, mock_init):
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_batch.py b/tests/providers/microsoft/azure/hooks/test_azure_batch.py
index 9549e5795a2e5..4361170b18fda 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_batch.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_batch.py
@@ -82,8 +82,8 @@ def setUp(self):
def test_connection_and_client(self):
hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id)
- self.assertIsInstance(hook._connection(), Connection)
- self.assertIsInstance(hook.get_conn(), BatchServiceClient)
+ assert isinstance(hook._connection(), Connection)
+ assert isinstance(hook.get_conn(), BatchServiceClient)
def test_configure_pool_with_vm_config(self):
hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id)
@@ -95,7 +95,7 @@ def test_configure_pool_with_vm_config(self):
vm_offer="test.vm.offer",
sku_starts_with="test-sku",
)
- self.assertIsInstance(pool, batch_models.PoolAddParameter)
+ assert isinstance(pool, batch_models.PoolAddParameter)
def test_configure_pool_with_cloud_config(self):
hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id)
@@ -107,7 +107,7 @@ def test_configure_pool_with_cloud_config(self):
vm_offer="test.vm.offer",
sku_starts_with="test-sku",
)
- self.assertIsInstance(pool, batch_models.PoolAddParameter)
+ assert isinstance(pool, batch_models.PoolAddParameter)
def test_configure_pool_with_latest_vm(self):
with mock.patch(
@@ -125,7 +125,7 @@ def test_configure_pool_with_latest_vm(self):
vm_offer="test.vm.offer",
sku_starts_with="test-sku",
)
- self.assertIsInstance(pool, batch_models.PoolAddParameter)
+ assert isinstance(pool, batch_models.PoolAddParameter)
@mock.patch("airflow.providers.microsoft.azure.hooks.azure_batch.BatchServiceClient")
def test_create_pool_with_vm_config(self, mock_batch):
@@ -168,7 +168,7 @@ def test_job_configuration_and_create_job(self, mock_batch):
mock_instance = mock_batch.return_value.job.add
job = hook.configure_job(job_id='myjob', pool_id='mypool')
hook.create_job(job)
- self.assertIsInstance(job, batch_models.JobAddParameter)
+ assert isinstance(job, batch_models.JobAddParameter)
mock_instance.assert_called_once_with(job)
@mock.patch('airflow.providers.microsoft.azure.hooks.azure_batch.BatchServiceClient')
@@ -177,7 +177,7 @@ def test_add_single_task_to_job(self, mock_batch):
mock_instance = mock_batch.return_value.task.add
task = hook.configure_task(task_id="mytask", command_line="echo hello")
hook.add_single_task_to_job(job_id='myjob', task=task)
- self.assertIsInstance(task, batch_models.TaskAddParameter)
+ assert isinstance(task, batch_models.TaskAddParameter)
mock_instance.assert_called_once_with(job_id="myjob", task=task)
@mock.patch('airflow.providers.microsoft.azure.hooks.azure_batch.BatchServiceClient')
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py b/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py
index a38c93ca26ec1..b744ed2d04918 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py
@@ -71,7 +71,7 @@ def test_get_logs(self, list_logs_mock):
logs = self.hook.get_logs('resource_group', 'name', 'name')
- self.assertSequenceEqual(logs, expected_messages)
+ assert logs == expected_messages
@patch('azure.mgmt.containerinstance.operations.ContainerGroupsOperations.delete')
def test_delete(self, delete_mock):
@@ -86,7 +86,7 @@ def test_exists_with_existing(self, list_mock):
containers=[Container(name='test1', image='hello-world', resources=self.resources)],
)
]
- self.assertFalse(self.hook.exists('test', 'test1'))
+ assert not self.hook.exists('test', 'test1')
@patch('azure.mgmt.containerinstance.operations.ContainerGroupsOperations.list_by_resource_group')
def test_exists_with_not_existing(self, list_mock):
@@ -96,4 +96,4 @@ def test_exists_with_not_existing(self, list_mock):
containers=[Container(name='test1', image='hello-world', resources=self.resources)],
)
]
- self.assertFalse(self.hook.exists('test', 'not found'))
+ assert not self.hook.exists('test', 'not found')
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py b/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py
index 703b4ce5de8e3..7fc42113983d5 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py
@@ -35,7 +35,7 @@ def test_get_conn(self):
)
)
hook = AzureContainerRegistryHook(conn_id='azure_container_registry')
- self.assertIsNotNone(hook.connection)
- self.assertEqual(hook.connection.username, 'myuser')
- self.assertEqual(hook.connection.password, 'password')
- self.assertEqual(hook.connection.server, 'test.cr')
+ assert hook.connection is not None
+ assert hook.connection.username == 'myuser'
+ assert hook.connection.password == 'password'
+ assert hook.connection.server == 'test.cr'
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py b/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py
index ba5fb3720fff2..ed9d5e11782d5 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py
@@ -30,9 +30,9 @@ def test_get_file_volume(self):
volume = hook.get_file_volume(
mount_name='mount', share_name='share', storage_account_name='storage', read_only=True
)
- self.assertIsNotNone(volume)
- self.assertEqual(volume.name, 'mount')
- self.assertEqual(volume.azure_file.share_name, 'share')
- self.assertEqual(volume.azure_file.storage_account_key, 'key')
- self.assertEqual(volume.azure_file.storage_account_name, 'storage')
- self.assertEqual(volume.azure_file.read_only, True)
+ assert volume is not None
+ assert volume.name == 'mount'
+ assert volume.azure_file.share_name == 'share'
+ assert volume.azure_file.storage_account_key == 'key'
+ assert volume.azure_file.storage_account_name == 'storage'
+ assert volume.azure_file.read_only is True
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
index 701abc5a34143..1ff086d759f88 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
@@ -24,6 +24,7 @@
import uuid
from unittest import mock
+import pytest
from azure.cosmos.cosmos_client import CosmosClient
from airflow.exceptions import AirflowException
@@ -61,8 +62,8 @@ def setUp(self):
@mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient', autospec=True)
def test_client(self, mock_cosmos):
hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
- self.assertIsNone(hook._conn)
- self.assertIsInstance(hook.get_conn(), CosmosClient)
+ assert hook._conn is None
+ assert isinstance(hook.get_conn(), CosmosClient)
@mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient')
def test_create_database(self, mock_cosmos):
@@ -75,12 +76,14 @@ def test_create_database(self, mock_cosmos):
@mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient')
def test_create_database_exception(self, mock_cosmos):
hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
- self.assertRaises(AirflowException, hook.create_database, None)
+ with pytest.raises(AirflowException):
+ hook.create_database(None)
@mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient')
def test_create_container_exception(self, mock_cosmos):
hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
- self.assertRaises(AirflowException, hook.create_collection, None)
+ with pytest.raises(AirflowException):
+ hook.create_collection(None)
@mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient')
def test_create_container(self, mock_cosmos):
@@ -117,7 +120,7 @@ def test_upsert_document_default(self, mock_cosmos):
mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
mock_cosmos.assert_has_calls(expected_calls)
logging.getLogger().info(returned_item)
- self.assertEqual(returned_item['id'], test_id)
+ assert returned_item['id'] == test_id
@mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient')
def test_upsert_document(self, mock_cosmos):
@@ -141,7 +144,7 @@ def test_upsert_document(self, mock_cosmos):
mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
mock_cosmos.assert_has_calls(expected_calls)
logging.getLogger().info(returned_item)
- self.assertEqual(returned_item['id'], test_id)
+ assert returned_item['id'] == test_id
@mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient')
def test_insert_documents(self, mock_cosmos):
@@ -185,12 +188,14 @@ def test_delete_database(self, mock_cosmos):
@mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient')
def test_delete_database_exception(self, mock_cosmos):
hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
- self.assertRaises(AirflowException, hook.delete_database, None)
+ with pytest.raises(AirflowException):
+ hook.delete_database(None)
@mock.patch('azure.cosmos.cosmos_client.CosmosClient')
def test_delete_container_exception(self, mock_cosmos):
hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
- self.assertRaises(AirflowException, hook.delete_collection, None)
+ with pytest.raises(AirflowException):
+ hook.delete_collection(None)
@mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient')
def test_delete_container(self, mock_cosmos):
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py b/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py
index 8c70749299784..046f5565ae059 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py
@@ -43,9 +43,9 @@ def test_conn(self, mock_lib):
from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook
hook = AzureDataLakeHook(azure_data_lake_conn_id='adl_test_key')
- self.assertIsNone(hook._conn)
- self.assertEqual(hook.conn_id, 'adl_test_key')
- self.assertIsInstance(hook.get_conn(), core.AzureDLFileSystem)
+ assert hook._conn is None
+ assert hook.conn_id == 'adl_test_key'
+ assert isinstance(hook.get_conn(), core.AzureDLFileSystem)
assert mock_lib.auth.called
@mock.patch(
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py b/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py
index a9a6298ddf641..72ff6777fe7e5 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py
@@ -52,23 +52,23 @@ def test_key_and_connection(self):
from azure.storage.file import FileService
hook = AzureFileShareHook(wasb_conn_id='wasb_test_key')
- self.assertEqual(hook.conn_id, 'wasb_test_key')
- self.assertIsNone(hook._conn)
- self.assertIsInstance(hook.get_conn(), FileService)
+ assert hook.conn_id == 'wasb_test_key'
+ assert hook._conn is None
+ assert isinstance(hook.get_conn(), FileService)
def test_sas_token(self):
from azure.storage.file import FileService
hook = AzureFileShareHook(wasb_conn_id='wasb_test_sas_token')
- self.assertEqual(hook.conn_id, 'wasb_test_sas_token')
- self.assertIsInstance(hook.get_conn(), FileService)
+ assert hook.conn_id == 'wasb_test_sas_token'
+ assert isinstance(hook.get_conn(), FileService)
@mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', autospec=True)
def test_check_for_file(self, mock_service):
mock_instance = mock_service.return_value
mock_instance.exists.return_value = True
hook = AzureFileShareHook(wasb_conn_id='wasb_test_sas_token')
- self.assertTrue(hook.check_for_file('share', 'directory', 'file', timeout=3))
+ assert hook.check_for_file('share', 'directory', 'file', timeout=3)
mock_instance.exists.assert_called_once_with('share', 'directory', 'file', timeout=3)
@mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', autospec=True)
@@ -76,7 +76,7 @@ def test_check_for_directory(self, mock_service):
mock_instance = mock_service.return_value
mock_instance.exists.return_value = True
hook = AzureFileShareHook(wasb_conn_id='wasb_test_sas_token')
- self.assertTrue(hook.check_for_directory('share', 'directory', timeout=3))
+ assert hook.check_for_directory('share', 'directory', timeout=3)
mock_instance.exists.assert_called_once_with('share', 'directory', timeout=3)
@mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', autospec=True)
@@ -124,7 +124,7 @@ def test_list_files(self, mock_service):
]
hook = AzureFileShareHook(wasb_conn_id='wasb_test_sas_token')
files = hook.list_files('share', 'directory', timeout=1)
- self.assertEqual(files, ["file1", 'file2'])
+ assert files == ["file1", 'file2']
mock_instance.list_directories_and_files.assert_called_once_with('share', 'directory', timeout=1)
@mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', autospec=True)
diff --git a/tests/providers/microsoft/azure/hooks/test_wasb.py b/tests/providers/microsoft/azure/hooks/test_wasb.py
index 5d092445e101f..16556a1b463ec 100644
--- a/tests/providers/microsoft/azure/hooks/test_wasb.py
+++ b/tests/providers/microsoft/azure/hooks/test_wasb.py
@@ -58,34 +58,34 @@ def setUp(self):
def test_key(self):
hook = WasbHook(wasb_conn_id='wasb_test_key')
- self.assertEqual(hook.conn_id, 'wasb_test_key')
- self.assertIsInstance(hook.connection, BlockBlobService)
+ assert hook.conn_id == 'wasb_test_key'
+ assert isinstance(hook.connection, BlockBlobService)
def test_sas_token(self):
hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
- self.assertEqual(hook.conn_id, 'wasb_test_sas_token')
- self.assertIsInstance(hook.connection, BlockBlobService)
+ assert hook.conn_id == 'wasb_test_sas_token'
+ assert isinstance(hook.connection, BlockBlobService)
@mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
def test_check_for_blob(self, mock_service):
mock_instance = mock_service.return_value
mock_instance.exists.return_value = True
hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
- self.assertTrue(hook.check_for_blob('container', 'blob', timeout=3))
+ assert hook.check_for_blob('container', 'blob', timeout=3)
mock_instance.exists.assert_called_once_with('container', 'blob', timeout=3)
@mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
def test_check_for_blob_empty(self, mock_service):
mock_service.return_value.exists.return_value = False
hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
- self.assertFalse(hook.check_for_blob('container', 'blob'))
+ assert not hook.check_for_blob('container', 'blob')
@mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
def test_check_for_prefix(self, mock_service):
mock_instance = mock_service.return_value
mock_instance.list_blobs.return_value = iter(['blob_1'])
hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
- self.assertTrue(hook.check_for_prefix('container', 'prefix', timeout=3))
+ assert hook.check_for_prefix('container', 'prefix', timeout=3)
mock_instance.list_blobs.assert_called_once_with('container', 'prefix', num_results=1, timeout=3)
@mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
@@ -93,7 +93,7 @@ def test_check_for_prefix_empty(self, mock_service):
mock_instance = mock_service.return_value
mock_instance.list_blobs.return_value = iter([])
hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
- self.assertFalse(hook.check_for_prefix('container', 'prefix'))
+ assert not hook.check_for_prefix('container', 'prefix')
@mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
def test_load_file(self, mock_service):
@@ -153,18 +153,18 @@ def test_delete_nonexisting_blob_fails(self, mock_service):
mock_instance = mock_service.return_value
mock_instance.exists.return_value = False
hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
- with self.assertRaises(Exception) as context:
+ with pytest.raises(Exception) as ctx:
hook.delete_file('container', 'nonexisting_blob', is_prefix=False, ignore_if_missing=False)
- self.assertIsInstance(context.exception, AirflowException)
+ assert isinstance(ctx.value, AirflowException)
@mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
def test_delete_multiple_nonexisting_blobs_fails(self, mock_service):
mock_instance = mock_service.return_value
mock_instance.list_blobs.return_value = iter([])
hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
- with self.assertRaises(Exception) as context:
+ with pytest.raises(Exception) as ctx:
hook.delete_file('container', 'nonexisting_blob_prefix', is_prefix=True, ignore_if_missing=False)
- self.assertIsInstance(context.exception, AirflowException)
+ assert isinstance(ctx.value, AirflowException)
@mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
def test_get_blobs_list(self, mock_service):
diff --git a/tests/providers/microsoft/azure/log/test_wasb_task_handler.py b/tests/providers/microsoft/azure/log/test_wasb_task_handler.py
index df75249841e7d..f85964995f4b8 100644
--- a/tests/providers/microsoft/azure/log/test_wasb_task_handler.py
+++ b/tests/providers/microsoft/azure/log/test_wasb_task_handler.py
@@ -57,7 +57,7 @@ def setUp(self):
@conf_vars({('logging', 'remote_log_conn_id'): 'wasb_default'})
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService")
def test_hook(self, mock_service):
- self.assertIsInstance(self.wasb_task_handler.hook, WasbHook)
+ assert isinstance(self.wasb_task_handler.hook, WasbHook)
@conf_vars({('logging', 'remote_log_conn_id'): 'wasb_default'})
def test_hook_raises(self):
@@ -80,11 +80,11 @@ def test_hook_raises(self):
def test_set_context_raw(self):
self.ti.raw = True
self.wasb_task_handler.set_context(self.ti)
- self.assertFalse(self.wasb_task_handler.upload_on_close)
+ assert not self.wasb_task_handler.upload_on_close
def test_set_context_not_raw(self):
self.wasb_task_handler.set_context(self.ti)
- self.assertTrue(self.wasb_task_handler.upload_on_close)
+ assert self.wasb_task_handler.upload_on_close
# The `azure` provider uses legacy `azure-storage` library, where `snowflake` uses the
# newer and more stable versions of those libraries. Most of `azure` operators and hooks work
@@ -114,21 +114,18 @@ def test_wasb_log_exists(self, mock_hook):
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook")
def test_wasb_read(self, mock_hook):
mock_hook.return_value.read_file.return_value = 'Log line'
- self.assertEqual(self.wasb_task_handler.wasb_read(self.remote_log_location), "Log line")
- self.assertEqual(
- self.wasb_task_handler.read(self.ti),
- (
+ assert self.wasb_task_handler.wasb_read(self.remote_log_location) == "Log line"
+ assert self.wasb_task_handler.read(self.ti) == (
+ [
[
- [
- (
- '',
- '*** Reading remote log from wasb://container/remote/log/location/1.log.\n'
- 'Log line\n',
- )
- ]
- ],
- [{'end_of_log': True}],
- ),
+ (
+ '',
+ '*** Reading remote log from wasb://container/remote/log/location/1.log.\n'
+ 'Log line\n',
+ )
+ ]
+ ],
+ [{'end_of_log': True}],
)
def test_wasb_read_raises(self):
diff --git a/tests/providers/microsoft/azure/operators/test_adls_list.py b/tests/providers/microsoft/azure/operators/test_adls_list.py
index 9b0a5c2f7ed09..550704cbe0c87 100644
--- a/tests/providers/microsoft/azure/operators/test_adls_list.py
+++ b/tests/providers/microsoft/azure/operators/test_adls_list.py
@@ -41,4 +41,4 @@ def test_execute(self, mock_hook):
files = operator.execute(None)
mock_hook.return_value.list.assert_called_once_with(path=TEST_PATH)
- self.assertEqual(sorted(files), sorted(MOCK_FILES))
+ assert sorted(files) == sorted(MOCK_FILES)
diff --git a/tests/providers/microsoft/azure/operators/test_adx.py b/tests/providers/microsoft/azure/operators/test_adx.py
index a425755b16fd3..d8b080b419a99 100644
--- a/tests/providers/microsoft/azure/operators/test_adx.py
+++ b/tests/providers/microsoft/azure/operators/test_adx.py
@@ -68,10 +68,10 @@ def setUp(self):
self.operator = AzureDataExplorerQueryOperator(dag=self.dag, **MOCK_DATA)
def test_init(self):
- self.assertEqual(self.operator.task_id, MOCK_DATA['task_id'])
- self.assertEqual(self.operator.query, MOCK_DATA['query'])
- self.assertEqual(self.operator.database, MOCK_DATA['database'])
- self.assertEqual(self.operator.azure_data_explorer_conn_id, 'azure_data_explorer_default')
+ assert self.operator.task_id == MOCK_DATA['task_id']
+ assert self.operator.query == MOCK_DATA['query']
+ assert self.operator.database == MOCK_DATA['database']
+ assert self.operator.azure_data_explorer_conn_id == 'azure_data_explorer_default'
@mock.patch.object(AzureDataExplorerHook, 'run_query', return_value=MockResponse())
@mock.patch.object(AzureDataExplorerHook, 'get_conn')
@@ -87,4 +87,4 @@ def test_xcom_push_and_pull(self, mock_conn, mock_run_query):
ti = TaskInstance(task=self.operator, execution_date=timezone.utcnow())
ti.run()
- self.assertEqual(ti.xcom_pull(task_ids=MOCK_DATA['task_id']), str(MOCK_RESULT))
+ assert ti.xcom_pull(task_ids=MOCK_DATA['task_id']) == str(MOCK_RESULT)
diff --git a/tests/providers/microsoft/azure/operators/test_azure_batch.py b/tests/providers/microsoft/azure/operators/test_azure_batch.py
index 02bac969e9f81..926872a7ffdc9 100644
--- a/tests/providers/microsoft/azure/operators/test_azure_batch.py
+++ b/tests/providers/microsoft/azure/operators/test_azure_batch.py
@@ -20,6 +20,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.microsoft.azure.hooks.azure_batch import AzureBatchHook
@@ -174,7 +176,7 @@ def setUp(self, mock_batch, mock_hook):
)
self.batch_client = mock_batch.return_value
self.mock_instance = mock_hook.return_value
- self.assertEqual(self.batch_client, self.operator.hook.connection)
+ assert self.batch_client == self.operator.hook.connection
@mock.patch.object(AzureBatchHook, 'wait_for_all_node_state')
def test_execute_without_failures(self, wait_mock):
@@ -201,7 +203,7 @@ def test_execute_with_failures(self, wait_mock):
self.operator.batch_pool_id = None
# test that it raises
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.operator.execute(None)
@mock.patch.object(AzureBatchHook, 'wait_for_all_node_state')
@@ -217,38 +219,36 @@ def test_execute_with_cleaning(self, mock_clean, wait_mock):
@mock.patch.object(AzureBatchHook, "wait_for_all_node_state")
def test_operator_fails(self, wait_mock):
wait_mock.return_value = True
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
self.operator_fail.execute(None)
- self.assertEqual(
- str(e.exception),
- "Either target_dedicated_nodes or enable_auto_scale must be set. None was set",
+ assert (
+ str(ctx.value) == "Either target_dedicated_nodes or enable_auto_scale must be set. None was set"
)
@mock.patch.object(AzureBatchHook, "wait_for_all_node_state")
def test_operator_fails_no_formula(self, wait_mock):
wait_mock.return_value = True
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
self.operator2_no_formula.execute(None)
- self.assertEqual(str(e.exception), "The auto_scale_formula is required when enable_auto_scale is set")
+ assert str(ctx.value) == "The auto_scale_formula is required when enable_auto_scale is set"
@mock.patch.object(AzureBatchHook, "wait_for_all_node_state")
def test_operator_fails_mutual_exclusive(self, wait_mock):
wait_mock.return_value = True
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
self.operator_mutual_exclusive.execute(None)
- self.assertEqual(
- str(e.exception),
- "Cloud service configuration and virtual machine configuration "
+ assert (
+ str(ctx.value) == "Cloud service configuration and virtual machine configuration "
"are mutually exclusive. You must specify either of os_family and"
- " vm_publisher",
+ " vm_publisher"
)
@mock.patch.object(AzureBatchHook, "wait_for_all_node_state")
def test_operator_fails_invalid_args(self, wait_mock):
wait_mock.return_value = True
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
self.operator_invalid.execute(None)
- self.assertEqual(str(e.exception), "You must specify either vm_publisher or os_family")
+ assert str(ctx.value) == "You must specify either vm_publisher or os_family"
def test_cleaning_works(self):
self.operator.clean_up(job_id="myjob")
diff --git a/tests/providers/microsoft/azure/operators/test_azure_container_instances.py b/tests/providers/microsoft/azure/operators/test_azure_container_instances.py
index 90f9ece7c3435..7d2118f8d0c6b 100644
--- a/tests/providers/microsoft/azure/operators/test_azure_container_instances.py
+++ b/tests/providers/microsoft/azure/operators/test_azure_container_instances.py
@@ -22,6 +22,7 @@
from unittest import mock
from unittest.mock import MagicMock
+import pytest
from azure.mgmt.containerinstance.models import ContainerState, Event
from airflow.exceptions import AirflowException
@@ -69,22 +70,22 @@ def test_execute(self, aci_mock):
)
aci.execute(None)
- self.assertEqual(aci_mock.return_value.create_or_update.call_count, 1)
+ assert aci_mock.return_value.create_or_update.call_count == 1
(called_rg, called_cn, called_cg), _ = aci_mock.return_value.create_or_update.call_args
- self.assertEqual(called_rg, 'resource-group')
- self.assertEqual(called_cn, 'container-name')
+ assert called_rg == 'resource-group'
+ assert called_cn == 'container-name'
- self.assertEqual(called_cg.location, 'region')
- self.assertEqual(called_cg.image_registry_credentials, None)
- self.assertEqual(called_cg.restart_policy, 'Never')
- self.assertEqual(called_cg.os_type, 'Linux')
+ assert called_cg.location == 'region'
+ assert called_cg.image_registry_credentials is None
+ assert called_cg.restart_policy == 'Never'
+ assert called_cg.os_type == 'Linux'
called_cg_container = called_cg.containers[0]
- self.assertEqual(called_cg_container.name, 'container-name')
- self.assertEqual(called_cg_container.image, 'container-image')
+ assert called_cg_container.name == 'container-name'
+ assert called_cg_container.image == 'container-image'
- self.assertEqual(aci_mock.return_value.delete.call_count, 1)
+ assert aci_mock.return_value.delete.call_count == 1
@mock.patch(
"airflow.providers.microsoft.azure.operators." "azure_container_instances.AzureContainerInstanceHook"
@@ -105,10 +106,10 @@ def test_execute_with_failures(self, aci_mock):
region='region',
task_id='task',
)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
aci.execute(None)
- self.assertEqual(aci_mock.return_value.delete.call_count, 1)
+ assert aci_mock.return_value.delete.call_count == 1
@mock.patch(
"airflow.providers.microsoft.azure.operators." "azure_container_instances.AzureContainerInstanceHook"
@@ -133,23 +134,23 @@ def test_execute_with_tags(self, aci_mock):
)
aci.execute(None)
- self.assertEqual(aci_mock.return_value.create_or_update.call_count, 1)
+ assert aci_mock.return_value.create_or_update.call_count == 1
(called_rg, called_cn, called_cg), _ = aci_mock.return_value.create_or_update.call_args
- self.assertEqual(called_rg, 'resource-group')
- self.assertEqual(called_cn, 'container-name')
+ assert called_rg == 'resource-group'
+ assert called_cn == 'container-name'
- self.assertEqual(called_cg.location, 'region')
- self.assertEqual(called_cg.image_registry_credentials, None)
- self.assertEqual(called_cg.restart_policy, 'Never')
- self.assertEqual(called_cg.os_type, 'Linux')
- self.assertEqual(called_cg.tags, tags)
+ assert called_cg.location == 'region'
+ assert called_cg.image_registry_credentials is None
+ assert called_cg.restart_policy == 'Never'
+ assert called_cg.os_type == 'Linux'
+ assert called_cg.tags == tags
called_cg_container = called_cg.containers[0]
- self.assertEqual(called_cg_container.name, 'container-name')
- self.assertEqual(called_cg_container.image, 'container-image')
+ assert called_cg_container.name == 'container-name'
+ assert called_cg_container.image == 'container-image'
- self.assertEqual(aci_mock.return_value.delete.call_count, 1)
+ assert aci_mock.return_value.delete.call_count == 1
@mock.patch(
"airflow.providers.microsoft.azure.operators." "azure_container_instances.AzureContainerInstanceHook"
@@ -176,11 +177,11 @@ def test_execute_with_messages_logs(self, aci_mock):
)
aci.execute(None)
- self.assertEqual(aci_mock.return_value.create_or_update.call_count, 1)
- self.assertEqual(aci_mock.return_value.get_state.call_count, 2)
- self.assertEqual(aci_mock.return_value.get_logs.call_count, 2)
+ assert aci_mock.return_value.create_or_update.call_count == 1
+ assert aci_mock.return_value.get_state.call_count == 2
+ assert aci_mock.return_value.get_logs.call_count == 2
- self.assertEqual(aci_mock.return_value.delete.call_count, 1)
+ assert aci_mock.return_value.delete.call_count == 1
def test_name_checker(self):
valid_names = ['test-dash', 'name-with-length---63' * 3]
@@ -192,12 +193,12 @@ def test_name_checker(self):
'-name-starting-with-dash',
]
for name in invalid_names:
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
AzureContainerInstancesOperator._check_name(name)
for name in valid_names:
checked_name = AzureContainerInstancesOperator._check_name(name)
- self.assertEqual(checked_name, name)
+ assert checked_name == name
@mock.patch(
"airflow.providers.microsoft.azure.operators.azure_container_instances.AzureContainerInstanceHook"
@@ -221,10 +222,10 @@ def test_execute_with_ipaddress(self, aci_mock):
ip_address=ipaddress,
)
aci.execute(None)
- self.assertEqual(aci_mock.return_value.create_or_update.call_count, 1)
+ assert aci_mock.return_value.create_or_update.call_count == 1
(_, _, called_cg), _ = aci_mock.return_value.create_or_update.call_args
- self.assertEqual(called_cg.ip_address, ipaddress)
+ assert called_cg.ip_address == ipaddress
@mock.patch(
"airflow.providers.microsoft.azure.operators.azure_container_instances.AzureContainerInstanceHook"
@@ -248,11 +249,11 @@ def test_execute_with_windows_os_and_diff_restart_policy(self, aci_mock):
os_type='Windows',
)
aci.execute(None)
- self.assertEqual(aci_mock.return_value.create_or_update.call_count, 1)
+ assert aci_mock.return_value.create_or_update.call_count == 1
(_, _, called_cg), _ = aci_mock.return_value.create_or_update.call_args
- self.assertEqual(called_cg.restart_policy, 'Always')
- self.assertEqual(called_cg.os_type, 'Windows')
+ assert called_cg.restart_policy == 'Always'
+ assert called_cg.os_type == 'Windows'
@mock.patch(
"airflow.providers.microsoft.azure.operators.azure_container_instances.AzureContainerInstanceHook"
@@ -264,7 +265,7 @@ def test_execute_fails_with_incorrect_os_type(self, aci_mock):
aci_mock.return_value.get_state.return_value = expected_cg
aci_mock.return_value.exists.return_value = False
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
AzureContainerInstancesOperator(
ci_conn_id=None,
registry_conn_id=None,
@@ -276,11 +277,10 @@ def test_execute_fails_with_incorrect_os_type(self, aci_mock):
os_type='MacOs',
)
- self.assertEqual(
- str(e.exception),
- "Invalid value for the os_type argument. "
+ assert (
+ str(ctx.value) == "Invalid value for the os_type argument. "
"Please set 'Linux' or 'Windows' as the os_type. "
- "Found `MacOs`.",
+ "Found `MacOs`."
)
@mock.patch(
@@ -293,7 +293,7 @@ def test_execute_fails_with_incorrect_restart_policy(self, aci_mock):
aci_mock.return_value.get_state.return_value = expected_cg
aci_mock.return_value.exists.return_value = False
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
AzureContainerInstancesOperator(
ci_conn_id=None,
registry_conn_id=None,
@@ -305,9 +305,8 @@ def test_execute_fails_with_incorrect_restart_policy(self, aci_mock):
restart_policy='Everyday',
)
- self.assertEqual(
- str(e.exception),
- "Invalid value for the restart_policy argument. "
+ assert (
+ str(ctx.value) == "Invalid value for the restart_policy argument. "
"Please set one of 'Always', 'OnFailure','Never' as the restart_policy. "
- "Found `Everyday`",
+ "Found `Everyday`"
)
diff --git a/tests/providers/microsoft/azure/operators/test_wasb_delete_blob.py b/tests/providers/microsoft/azure/operators/test_wasb_delete_blob.py
index 04df34cd89ca6..b007efa10e3a1 100644
--- a/tests/providers/microsoft/azure/operators/test_wasb_delete_blob.py
+++ b/tests/providers/microsoft/azure/operators/test_wasb_delete_blob.py
@@ -38,16 +38,16 @@ def setUp(self):
def test_init(self):
operator = WasbDeleteBlobOperator(task_id='wasb_operator_1', dag=self.dag, **self._config)
- self.assertEqual(operator.container_name, self._config['container_name'])
- self.assertEqual(operator.blob_name, self._config['blob_name'])
- self.assertEqual(operator.is_prefix, False)
- self.assertEqual(operator.ignore_if_missing, False)
+ assert operator.container_name == self._config['container_name']
+ assert operator.blob_name == self._config['blob_name']
+ assert operator.is_prefix is False
+ assert operator.ignore_if_missing is False
operator = WasbDeleteBlobOperator(
task_id='wasb_operator_2', dag=self.dag, is_prefix=True, ignore_if_missing=True, **self._config
)
- self.assertEqual(operator.is_prefix, True)
- self.assertEqual(operator.ignore_if_missing, True)
+ assert operator.is_prefix is True
+ assert operator.ignore_if_missing is True
@mock.patch('airflow.providers.microsoft.azure.operators.wasb_delete_blob.WasbHook', autospec=True)
def test_execute(self, mock_hook):
diff --git a/tests/providers/microsoft/azure/secrets/test_azure_key_vault.py b/tests/providers/microsoft/azure/secrets/test_azure_key_vault.py
index e2cef7ad56eed..783dcf62f8b03 100644
--- a/tests/providers/microsoft/azure/secrets/test_azure_key_vault.py
+++ b/tests/providers/microsoft/azure/secrets/test_azure_key_vault.py
@@ -29,7 +29,7 @@ def test_get_connections(self, mock_get_uri):
mock_get_uri.return_value = 'scheme://user:pass@host:100'
conn_list = AzureKeyVaultBackend().get_connections('fake_conn')
conn = conn_list[0]
- self.assertEqual(conn.host, 'host')
+ assert conn.host == 'host'
@mock.patch('airflow.providers.microsoft.azure.secrets.azure_key_vault.DefaultAzureCredential')
@mock.patch('airflow.providers.microsoft.azure.secrets.azure_key_vault.SecretClient')
@@ -48,7 +48,7 @@ def test_get_conn_uri(self, mock_secret_client, mock_azure_cred):
mock_secret_client.assert_called_once_with(
credential=mock_cred, vault_url='https://example-akv-resource-name.vault.azure.net/'
)
- self.assertEqual(returned_uri, 'postgresql://airflow:airflow@host:5432/airflow')
+ assert returned_uri == 'postgresql://airflow:airflow@host:5432/airflow'
@mock.patch('airflow.providers.microsoft.azure.secrets.azure_key_vault.AzureKeyVaultBackend.client')
def test_get_conn_uri_non_existent_key(self, mock_client):
@@ -60,8 +60,8 @@ def test_get_conn_uri_non_existent_key(self, mock_client):
mock_client.get_secret.side_effect = ResourceNotFoundError
backend = AzureKeyVaultBackend(vault_url="https://example-akv-resource-name.vault.azure.net/")
- self.assertIsNone(backend.get_conn_uri(conn_id=conn_id))
- self.assertEqual([], backend.get_connections(conn_id=conn_id))
+ assert backend.get_conn_uri(conn_id=conn_id) is None
+ assert [] == backend.get_connections(conn_id=conn_id)
@mock.patch('airflow.providers.microsoft.azure.secrets.azure_key_vault.AzureKeyVaultBackend.client')
def test_get_variable(self, mock_client):
@@ -69,7 +69,7 @@ def test_get_variable(self, mock_client):
backend = AzureKeyVaultBackend()
returned_uri = backend.get_variable('hello')
mock_client.get_secret.assert_called_with(name='airflow-variables-hello')
- self.assertEqual('world', returned_uri)
+ assert 'world' == returned_uri
@mock.patch('airflow.providers.microsoft.azure.secrets.azure_key_vault.AzureKeyVaultBackend.client')
def test_get_variable_non_existent_key(self, mock_client):
@@ -79,7 +79,7 @@ def test_get_variable_non_existent_key(self, mock_client):
"""
mock_client.get_secret.side_effect = ResourceNotFoundError
backend = AzureKeyVaultBackend()
- self.assertIsNone(backend.get_variable('test_mysql'))
+ assert backend.get_variable('test_mysql') is None
@mock.patch('airflow.providers.microsoft.azure.secrets.azure_key_vault.AzureKeyVaultBackend.client')
def test_get_secret_value_not_found(self, mock_client):
@@ -88,8 +88,8 @@ def test_get_secret_value_not_found(self, mock_client):
"""
mock_client.get_secret.side_effect = ResourceNotFoundError
backend = AzureKeyVaultBackend()
- self.assertIsNone(
- backend._get_secret(path_prefix=backend.connections_prefix, secret_id='test_non_existent')
+ assert (
+ backend._get_secret(path_prefix=backend.connections_prefix, secret_id='test_non_existent') is None
)
@mock.patch('airflow.providers.microsoft.azure.secrets.azure_key_vault.AzureKeyVaultBackend.client')
@@ -101,7 +101,7 @@ def test_get_secret_value(self, mock_client):
backend = AzureKeyVaultBackend()
secret_val = backend._get_secret('af-secrets', 'test_mysql_password')
mock_client.get_secret.assert_called_with(name='af-secrets-test-mysql-password')
- self.assertEqual(secret_val, 'super-secret')
+ assert secret_val == 'super-secret'
@mock.patch('airflow.providers.microsoft.azure.secrets.azure_key_vault.AzureKeyVaultBackend._get_secret')
def test_connection_prefix_none_value(self, mock_get_secret):
@@ -113,7 +113,7 @@ def test_connection_prefix_none_value(self, mock_get_secret):
kwargs = {'connections_prefix': None}
backend = AzureKeyVaultBackend(**kwargs)
- self.assertIsNone(backend.get_conn_uri('test_mysql'))
+ assert backend.get_conn_uri('test_mysql') is None
mock_get_secret.assert_not_called()
@mock.patch('airflow.providers.microsoft.azure.secrets.azure_key_vault.AzureKeyVaultBackend._get_secret')
@@ -126,7 +126,7 @@ def test_variable_prefix_none_value(self, mock_get_secret):
kwargs = {'variables_prefix': None}
backend = AzureKeyVaultBackend(**kwargs)
- self.assertIsNone(backend.get_variable('hello'))
+ assert backend.get_variable('hello') is None
mock_get_secret.assert_not_called()
@mock.patch('airflow.providers.microsoft.azure.secrets.azure_key_vault.AzureKeyVaultBackend._get_secret')
@@ -139,5 +139,5 @@ def test_config_prefix_none_value(self, mock_get_secret):
kwargs = {'config_prefix': None}
backend = AzureKeyVaultBackend(**kwargs)
- self.assertIsNone(backend.get_config('test_mysql'))
+ assert backend.get_config('test_mysql') is None
mock_get_secret.assert_not_called()
diff --git a/tests/providers/microsoft/azure/sensors/test_azure_cosmos.py b/tests/providers/microsoft/azure/sensors/test_azure_cosmos.py
index c47202e481a4d..a2eafe7ad638f 100644
--- a/tests/providers/microsoft/azure/sensors/test_azure_cosmos.py
+++ b/tests/providers/microsoft/azure/sensors/test_azure_cosmos.py
@@ -38,7 +38,7 @@ def test_should_call_hook_with_args(self, mock_hook):
)
result = sensor.poke(None)
mock_instance.get_document.assert_called_once_with(DOCUMENT_ID, DB_NAME, COLLECTION_NAME)
- self.assertEqual(result, True)
+ assert result is True
@mock.patch('airflow.providers.microsoft.azure.sensors.azure_cosmos.AzureCosmosDBHook')
def test_should_return_false_on_no_document(self, mock_hook):
@@ -52,4 +52,4 @@ def test_should_return_false_on_no_document(self, mock_hook):
)
result = sensor.poke(None)
mock_instance.get_document.assert_called_once_with(DOCUMENT_ID, DB_NAME, COLLECTION_NAME)
- self.assertEqual(result, False)
+ assert result is False
diff --git a/tests/providers/microsoft/azure/sensors/test_wasb.py b/tests/providers/microsoft/azure/sensors/test_wasb.py
index 5aaec19e5e203..9ec7ef19b0660 100644
--- a/tests/providers/microsoft/azure/sensors/test_wasb.py
+++ b/tests/providers/microsoft/azure/sensors/test_wasb.py
@@ -39,16 +39,16 @@ def setUp(self):
def test_init(self):
sensor = WasbBlobSensor(task_id='wasb_sensor_1', dag=self.dag, **self._config)
- self.assertEqual(sensor.container_name, self._config['container_name'])
- self.assertEqual(sensor.blob_name, self._config['blob_name'])
- self.assertEqual(sensor.wasb_conn_id, self._config['wasb_conn_id'])
- self.assertEqual(sensor.check_options, {})
- self.assertEqual(sensor.timeout, self._config['timeout'])
+ assert sensor.container_name == self._config['container_name']
+ assert sensor.blob_name == self._config['blob_name']
+ assert sensor.wasb_conn_id == self._config['wasb_conn_id']
+ assert sensor.check_options == {}
+ assert sensor.timeout == self._config['timeout']
sensor = WasbBlobSensor(
task_id='wasb_sensor_2', dag=self.dag, check_options={'timeout': 2}, **self._config
)
- self.assertEqual(sensor.check_options, {'timeout': 2})
+ assert sensor.check_options == {'timeout': 2}
@mock.patch('airflow.providers.microsoft.azure.sensors.wasb.WasbHook', autospec=True)
def test_poke(self, mock_hook):
@@ -74,16 +74,16 @@ def setUp(self):
def test_init(self):
sensor = WasbPrefixSensor(task_id='wasb_sensor_1', dag=self.dag, **self._config)
- self.assertEqual(sensor.container_name, self._config['container_name'])
- self.assertEqual(sensor.prefix, self._config['prefix'])
- self.assertEqual(sensor.wasb_conn_id, self._config['wasb_conn_id'])
- self.assertEqual(sensor.check_options, {})
- self.assertEqual(sensor.timeout, self._config['timeout'])
+ assert sensor.container_name == self._config['container_name']
+ assert sensor.prefix == self._config['prefix']
+ assert sensor.wasb_conn_id == self._config['wasb_conn_id']
+ assert sensor.check_options == {}
+ assert sensor.timeout == self._config['timeout']
sensor = WasbPrefixSensor(
task_id='wasb_sensor_2', dag=self.dag, check_options={'timeout': 2}, **self._config
)
- self.assertEqual(sensor.check_options, {'timeout': 2})
+ assert sensor.check_options == {'timeout': 2}
@mock.patch('airflow.providers.microsoft.azure.sensors.wasb.WasbHook', autospec=True)
def test_poke(self, mock_hook):
diff --git a/tests/providers/microsoft/azure/transfers/test_azure_blob_to_gcs.py b/tests/providers/microsoft/azure/transfers/test_azure_blob_to_gcs.py
index b9517a894c727..a1c62d24024f2 100644
--- a/tests/providers/microsoft/azure/transfers/test_azure_blob_to_gcs.py
+++ b/tests/providers/microsoft/azure/transfers/test_azure_blob_to_gcs.py
@@ -49,18 +49,18 @@ def test_init(self):
impersonation_chain=IMPERSONATION_CHAIN,
task_id=TASK_ID,
)
- self.assertEqual(operator.wasb_conn_id, WASB_CONN_ID)
- self.assertEqual(operator.blob_name, BLOB_NAME)
- self.assertEqual(operator.file_path, FILE_PATH)
- self.assertEqual(operator.container_name, CONTAINER_NAME)
- self.assertEqual(operator.gcp_conn_id, GCP_CONN_ID)
- self.assertEqual(operator.bucket_name, BUCKET_NAME)
- self.assertEqual(operator.object_name, OBJECT_NAME)
- self.assertEqual(operator.filename, FILENAME)
- self.assertEqual(operator.gzip, GZIP)
- self.assertEqual(operator.delegate_to, DELEGATE_TO)
- self.assertEqual(operator.impersonation_chain, IMPERSONATION_CHAIN)
- self.assertEqual(operator.task_id, TASK_ID)
+ assert operator.wasb_conn_id == WASB_CONN_ID
+ assert operator.blob_name == BLOB_NAME
+ assert operator.file_path == FILE_PATH
+ assert operator.container_name == CONTAINER_NAME
+ assert operator.gcp_conn_id == GCP_CONN_ID
+ assert operator.bucket_name == BUCKET_NAME
+ assert operator.object_name == OBJECT_NAME
+ assert operator.filename == FILENAME
+ assert operator.gzip == GZIP
+ assert operator.delegate_to == DELEGATE_TO
+ assert operator.impersonation_chain == IMPERSONATION_CHAIN
+ assert operator.task_id == TASK_ID
@mock.patch("airflow.providers.microsoft.azure.transfers.azure_blob_to_gcs.WasbHook")
@mock.patch("airflow.providers.microsoft.azure.transfers.azure_blob_to_gcs.GCSHook")
diff --git a/tests/providers/microsoft/azure/transfers/test_file_to_wasb.py b/tests/providers/microsoft/azure/transfers/test_file_to_wasb.py
index 5a4f14c729259..73deaaeed7dc6 100644
--- a/tests/providers/microsoft/azure/transfers/test_file_to_wasb.py
+++ b/tests/providers/microsoft/azure/transfers/test_file_to_wasb.py
@@ -41,17 +41,17 @@ def setUp(self):
def test_init(self):
operator = FileToWasbOperator(task_id='wasb_operator_1', dag=self.dag, **self._config)
- self.assertEqual(operator.file_path, self._config['file_path'])
- self.assertEqual(operator.container_name, self._config['container_name'])
- self.assertEqual(operator.blob_name, self._config['blob_name'])
- self.assertEqual(operator.wasb_conn_id, self._config['wasb_conn_id'])
- self.assertEqual(operator.load_options, {})
- self.assertEqual(operator.retries, self._config['retries'])
+ assert operator.file_path == self._config['file_path']
+ assert operator.container_name == self._config['container_name']
+ assert operator.blob_name == self._config['blob_name']
+ assert operator.wasb_conn_id == self._config['wasb_conn_id']
+ assert operator.load_options == {}
+ assert operator.retries == self._config['retries']
operator = FileToWasbOperator(
task_id='wasb_operator_2', dag=self.dag, load_options={'timeout': 2}, **self._config
)
- self.assertEqual(operator.load_options, {'timeout': 2})
+ assert operator.load_options == {'timeout': 2}
@mock.patch('airflow.providers.microsoft.azure.transfers.file_to_wasb.WasbHook', autospec=True)
def test_execute(self, mock_hook):
diff --git a/tests/providers/microsoft/azure/transfers/test_local_to_adls.py b/tests/providers/microsoft/azure/transfers/test_local_to_adls.py
index 0bcc371393a35..b0de93f5ec241 100644
--- a/tests/providers/microsoft/azure/transfers/test_local_to_adls.py
+++ b/tests/providers/microsoft/azure/transfers/test_local_to_adls.py
@@ -19,6 +19,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.microsoft.azure.transfers.local_to_adls import LocalToAzureDataLakeStorageOperator
@@ -49,9 +51,9 @@ def test_execute_raises_for_bad_glob_val(self, mock_hook):
operator = LocalToAzureDataLakeStorageOperator(
task_id=TASK_ID, local_path=BAD_LOCAL_PATH, remote_path=REMOTE_PATH
)
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
operator.execute(None)
- self.assertEqual(str(e.exception), "Recursive glob patterns using `**` are not supported")
+ assert str(ctx.value) == "Recursive glob patterns using `**` are not supported"
@mock.patch('airflow.providers.microsoft.azure.transfers.local_to_adls.AzureDataLakeHook')
def test_extra_options_is_passed(self, mock_hook):
diff --git a/tests/providers/microsoft/azure/transfers/test_oracle_to_azure_data_lake.py b/tests/providers/microsoft/azure/transfers/test_oracle_to_azure_data_lake.py
index 8d16878b2d45d..507e143a03774 100644
--- a/tests/providers/microsoft/azure/transfers/test_oracle_to_azure_data_lake.py
+++ b/tests/providers/microsoft/azure/transfers/test_oracle_to_azure_data_lake.py
@@ -75,11 +75,11 @@ def test_write_temp_file(self):
rownum = 0
for row in temp_file:
if rownum == 0:
- self.assertEqual(row[0], 'id')
- self.assertEqual(row[1], 'description')
+ assert row[0] == 'id'
+ assert row[1] == 'description'
else:
- self.assertEqual(row[0], str(cursor_rows[rownum - 1][0]))
- self.assertEqual(row[1], cursor_rows[rownum - 1][1])
+ assert row[0] == str(cursor_rows[rownum - 1][0])
+ assert row[1] == cursor_rows[rownum - 1][1]
rownum = rownum + 1
@mock.patch(mock_module_path + '.OracleHook', autospec=True)
diff --git a/tests/providers/microsoft/mssql/hooks/test_mssql.py b/tests/providers/microsoft/mssql/hooks/test_mssql.py
index 406a74751d475..63e032406d03e 100644
--- a/tests/providers/microsoft/mssql/hooks/test_mssql.py
+++ b/tests/providers/microsoft/mssql/hooks/test_mssql.py
@@ -39,7 +39,7 @@ def test_get_conn_should_return_connection(self, get_connection, mssql_get_conn)
hook = MsSqlHook()
conn = hook.get_conn()
- self.assertEqual(mssql_get_conn.return_value, conn)
+ assert mssql_get_conn.return_value == conn
mssql_get_conn.assert_called_once()
@unittest.skipIf(PY38, "Mssql package not available when Python >= 3.8.")
@@ -69,4 +69,4 @@ def test_get_autocommit_should_return_autocommit_state(self, get_connection, mss
conn = hook.get_conn()
mssql_get_conn.assert_called_once()
- self.assertEqual(hook.get_autocommit(conn), 'autocommit_state')
+ assert hook.get_autocommit(conn) == 'autocommit_state'
diff --git a/tests/providers/microsoft/winrm/hooks/test_winrm.py b/tests/providers/microsoft/winrm/hooks/test_winrm.py
index 8b9643ba533e7..042a2cd5a7cd6 100644
--- a/tests/providers/microsoft/winrm/hooks/test_winrm.py
+++ b/tests/providers/microsoft/winrm/hooks/test_winrm.py
@@ -20,6 +20,8 @@
import unittest
from unittest.mock import patch
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.microsoft.winrm.hooks.winrm import WinRMHook
@@ -33,17 +35,17 @@ def test_get_conn_exists(self, mock_protocol):
conn = winrm_hook.get_conn()
- self.assertEqual(conn, winrm_hook.client)
+ assert conn == winrm_hook.client
def test_get_conn_missing_remote_host(self):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
WinRMHook().get_conn()
@patch('airflow.providers.microsoft.winrm.hooks.winrm.Protocol')
def test_get_conn_error(self, mock_protocol):
mock_protocol.side_effect = Exception('Error')
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
WinRMHook(remote_host='host').get_conn()
@patch('airflow.providers.microsoft.winrm.hooks.winrm.Protocol', autospec=True)
@@ -107,7 +109,7 @@ def test_get_conn_no_username(self, mock_protocol, mock_getuser):
winrm_hook.get_conn()
- self.assertEqual(mock_getuser.return_value, winrm_hook.username)
+ assert mock_getuser.return_value == winrm_hook.username
@patch('airflow.providers.microsoft.winrm.hooks.winrm.Protocol')
def test_get_conn_no_endpoint(self, mock_protocol):
@@ -115,6 +117,4 @@ def test_get_conn_no_endpoint(self, mock_protocol):
winrm_hook.get_conn()
- self.assertEqual(
- f'http://{winrm_hook.remote_host}:{winrm_hook.remote_port}/wsman', winrm_hook.endpoint
- )
+ assert f'http://{winrm_hook.remote_host}:{winrm_hook.remote_port}/wsman' == winrm_hook.endpoint
diff --git a/tests/providers/microsoft/winrm/operators/test_winrm.py b/tests/providers/microsoft/winrm/operators/test_winrm.py
index ecc14651b2bc1..2e6426b16291a 100644
--- a/tests/providers/microsoft/winrm/operators/test_winrm.py
+++ b/tests/providers/microsoft/winrm/operators/test_winrm.py
@@ -19,6 +19,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.microsoft.winrm.operators.winrm import WinRMOperator
@@ -27,12 +29,12 @@ class TestWinRMOperator(unittest.TestCase):
def test_no_winrm_hook_no_ssh_conn_id(self):
op = WinRMOperator(task_id='test_task_id', winrm_hook=None, ssh_conn_id=None)
exception_msg = "Cannot operate without winrm_hook or ssh_conn_id."
- with self.assertRaisesRegex(AirflowException, exception_msg):
+ with pytest.raises(AirflowException, match=exception_msg):
op.execute(None)
@mock.patch('airflow.providers.microsoft.winrm.operators.winrm.WinRMHook')
def test_no_command(self, mock_hook):
op = WinRMOperator(task_id='test_task_id', winrm_hook=mock_hook, command=None)
exception_msg = "No command specified so nothing to execute here."
- with self.assertRaisesRegex(AirflowException, exception_msg):
+ with pytest.raises(AirflowException, match=exception_msg):
op.execute(None)
diff --git a/tests/providers/mongo/hooks/test_mongo.py b/tests/providers/mongo/hooks/test_mongo.py
index 8b5fa95d64ae1..8e8017831e986 100644
--- a/tests/providers/mongo/hooks/test_mongo.py
+++ b/tests/providers/mongo/hooks/test_mongo.py
@@ -58,13 +58,13 @@ def setUp(self):
@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_get_conn(self):
- self.assertEqual(self.hook.connection.port, 27017)
- self.assertIsInstance(self.conn, pymongo.MongoClient)
+ assert self.hook.connection.port == 27017
+ assert isinstance(self.conn, pymongo.MongoClient)
@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_srv(self):
hook = MongoHook(conn_id='mongo_default_with_srv')
- self.assertTrue(hook.uri.startswith('mongodb+srv://'))
+ assert hook.uri.startswith('mongodb+srv://')
@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_insert_one(self):
@@ -74,7 +74,7 @@ def test_insert_one(self):
result_obj = collection.find_one(filter=obj)
- self.assertEqual(obj, result_obj)
+ assert obj == result_obj
@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_insert_many(self):
@@ -84,7 +84,7 @@ def test_insert_many(self):
self.hook.insert_many(collection, objs)
result_objs = list(collection.find())
- self.assertEqual(len(result_objs), 2)
+ assert len(result_objs) == 2
@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_update_one(self):
@@ -98,7 +98,7 @@ def test_update_one(self):
self.hook.update_one(collection, filter_doc, update_doc)
result_obj = collection.find_one(filter='1')
- self.assertEqual(123, result_obj['field'])
+ assert 123 == result_obj['field']
@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_update_one_with_upsert(self):
@@ -110,7 +110,7 @@ def test_update_one_with_upsert(self):
self.hook.update_one(collection, filter_doc, update_doc, upsert=True)
result_obj = collection.find_one(filter='1')
- self.assertEqual(123, result_obj['field'])
+ assert 123 == result_obj['field']
@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_update_many(self):
@@ -125,10 +125,10 @@ def test_update_many(self):
self.hook.update_many(collection, filter_doc, update_doc)
result_obj = collection.find_one(filter='1')
- self.assertEqual(123, result_obj['field'])
+ assert 123 == result_obj['field']
result_obj = collection.find_one(filter='2')
- self.assertEqual(123, result_obj['field'])
+ assert 123 == result_obj['field']
@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_update_many_with_upsert(self):
@@ -140,7 +140,7 @@ def test_update_many_with_upsert(self):
self.hook.update_many(collection, filter_doc, update_doc, upsert=True)
result_obj = collection.find_one(filter='1')
- self.assertEqual(123, result_obj['field'])
+ assert 123 == result_obj['field']
@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_replace_one(self):
@@ -153,11 +153,11 @@ def test_replace_one(self):
self.hook.replace_one(collection, obj1)
result_obj = collection.find_one(filter='1')
- self.assertEqual('test_value_1_updated', result_obj['field'])
+ assert 'test_value_1_updated' == result_obj['field']
# Other document should stay intact
result_obj = collection.find_one(filter='2')
- self.assertEqual('test_value_2', result_obj['field'])
+ assert 'test_value_2' == result_obj['field']
@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_replace_one_with_filter(self):
@@ -170,11 +170,11 @@ def test_replace_one_with_filter(self):
self.hook.replace_one(collection, obj1, {'field': 'test_value_1'})
result_obj = collection.find_one(filter='1')
- self.assertEqual('test_value_1_updated', result_obj['field'])
+ assert 'test_value_1_updated' == result_obj['field']
# Other document should stay intact
result_obj = collection.find_one(filter='2')
- self.assertEqual('test_value_2', result_obj['field'])
+ assert 'test_value_2' == result_obj['field']
@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_replace_one_with_upsert(self):
@@ -184,7 +184,7 @@ def test_replace_one_with_upsert(self):
self.hook.replace_one(collection, obj, upsert=True)
result_obj = collection.find_one(filter='1')
- self.assertEqual('test_value_1', result_obj['field'])
+ assert 'test_value_1' == result_obj['field']
@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_replace_many(self):
@@ -198,10 +198,10 @@ def test_replace_many(self):
self.hook.replace_many(collection, [obj1, obj2])
result_obj = collection.find_one(filter='1')
- self.assertEqual('test_value_1_updated', result_obj['field'])
+ assert 'test_value_1_updated' == result_obj['field']
result_obj = collection.find_one(filter='2')
- self.assertEqual('test_value_2_updated', result_obj['field'])
+ assert 'test_value_2_updated' == result_obj['field']
@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_replace_many_with_upsert(self):
@@ -212,10 +212,10 @@ def test_replace_many_with_upsert(self):
self.hook.replace_many(collection, [obj1, obj2], upsert=True)
result_obj = collection.find_one(filter='1')
- self.assertEqual('test_value_1', result_obj['field'])
+ assert 'test_value_1' == result_obj['field']
result_obj = collection.find_one(filter='2')
- self.assertEqual('test_value_2', result_obj['field'])
+ assert 'test_value_2' == result_obj['field']
@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_delete_one(self):
@@ -225,7 +225,7 @@ def test_delete_one(self):
self.hook.delete_one(collection, {'_id': '1'})
- self.assertEqual(0, collection.count())
+ assert 0 == collection.count()
@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_delete_many(self):
@@ -236,7 +236,7 @@ def test_delete_many(self):
self.hook.delete_many(collection, {'field': 'value'})
- self.assertEqual(0, collection.count())
+ assert 0 == collection.count()
@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_find_one(self):
@@ -246,7 +246,7 @@ def test_find_one(self):
result_obj = self.hook.find(collection, {}, find_one=True)
result_obj = {result: result_obj[result] for result in result_obj}
- self.assertEqual(obj, result_obj)
+ assert obj == result_obj
@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_find_many(self):
@@ -256,7 +256,7 @@ def test_find_many(self):
result_objs = self.hook.find(collection, {}, find_one=False)
- self.assertGreater(len(list(result_objs)), 1)
+ assert len(list(result_objs)) > 1
@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_aggregate(self):
@@ -272,13 +272,13 @@ def test_aggregate(self):
aggregate_query = [{"$match": {'test_status': 'success'}}]
results = self.hook.aggregate(collection, aggregate_query)
- self.assertEqual(len(list(results)), 2)
+ assert len(list(results)) == 2
def test_context_manager(self):
with MongoHook(conn_id='mongo_default', mongo_db='default') as ctx_hook:
ctx_hook.get_conn()
- self.assertIsInstance(ctx_hook, MongoHook)
- self.assertIsNotNone(ctx_hook.client)
+ assert isinstance(ctx_hook, MongoHook)
+ assert ctx_hook.client is not None
- self.assertIsNone(ctx_hook.client)
+ assert ctx_hook.client is None
diff --git a/tests/providers/mongo/sensors/test_mongo.py b/tests/providers/mongo/sensors/test_mongo.py
index 688d19abcf593..6623631550855 100644
--- a/tests/providers/mongo/sensors/test_mongo.py
+++ b/tests/providers/mongo/sensors/test_mongo.py
@@ -52,4 +52,4 @@ def setUp(self):
)
def test_poke(self):
- self.assertTrue(self.sensor.poke(None))
+ assert self.sensor.poke(None)
diff --git a/tests/providers/mysql/hooks/test_mysql.py b/tests/providers/mysql/hooks/test_mysql.py
index 29cbafc45ac4e..538381f61bbe5 100644
--- a/tests/providers/mysql/hooks/test_mysql.py
+++ b/tests/providers/mysql/hooks/test_mysql.py
@@ -56,11 +56,11 @@ def test_get_conn(self, mock_connect):
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
- self.assertEqual(args, ())
- self.assertEqual(kwargs['user'], 'login')
- self.assertEqual(kwargs['passwd'], 'password')
- self.assertEqual(kwargs['host'], 'host')
- self.assertEqual(kwargs['db'], 'schema')
+ assert args == ()
+ assert kwargs['user'] == 'login'
+ assert kwargs['passwd'] == 'password'
+ assert kwargs['host'] == 'host'
+ assert kwargs['db'] == 'schema'
@mock.patch('MySQLdb.connect')
def test_get_uri(self, mock_connect):
@@ -68,7 +68,7 @@ def test_get_uri(self, mock_connect):
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
- self.assertEqual(self.db_hook.get_uri(), "mysql://login:password@host/schema?charset=utf-8")
+ assert self.db_hook.get_uri() == "mysql://login:password@host/schema?charset=utf-8"
@mock.patch('MySQLdb.connect')
def test_get_conn_from_connection(self, mock_connect):
@@ -94,8 +94,8 @@ def test_get_conn_port(self, mock_connect):
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
- self.assertEqual(args, ())
- self.assertEqual(kwargs['port'], 3307)
+ assert args == ()
+ assert kwargs['port'] == 3307
@mock.patch('MySQLdb.connect')
def test_get_conn_charset(self, mock_connect):
@@ -103,9 +103,9 @@ def test_get_conn_charset(self, mock_connect):
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
- self.assertEqual(args, ())
- self.assertEqual(kwargs['charset'], 'utf-8')
- self.assertEqual(kwargs['use_unicode'], True)
+ assert args == ()
+ assert kwargs['charset'] == 'utf-8'
+ assert kwargs['use_unicode'] is True
@mock.patch('MySQLdb.connect')
def test_get_conn_cursor(self, mock_connect):
@@ -113,8 +113,8 @@ def test_get_conn_cursor(self, mock_connect):
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
- self.assertEqual(args, ())
- self.assertEqual(kwargs['cursorclass'], MySQLdb.cursors.SSCursor)
+ assert args == ()
+ assert kwargs['cursorclass'] == MySQLdb.cursors.SSCursor
@mock.patch('MySQLdb.connect')
def test_get_conn_local_infile(self, mock_connect):
@@ -122,8 +122,8 @@ def test_get_conn_local_infile(self, mock_connect):
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
- self.assertEqual(args, ())
- self.assertEqual(kwargs['local_infile'], 1)
+ assert args == ()
+ assert kwargs['local_infile'] == 1
@mock.patch('MySQLdb.connect')
def test_get_con_unix_socket(self, mock_connect):
@@ -131,8 +131,8 @@ def test_get_con_unix_socket(self, mock_connect):
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
- self.assertEqual(args, ())
- self.assertEqual(kwargs['unix_socket'], '/tmp/socket')
+ assert args == ()
+ assert kwargs['unix_socket'] == '/tmp/socket'
@mock.patch('MySQLdb.connect')
def test_get_conn_ssl_as_dictionary(self, mock_connect):
@@ -140,8 +140,8 @@ def test_get_conn_ssl_as_dictionary(self, mock_connect):
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
- self.assertEqual(args, ())
- self.assertEqual(kwargs['ssl'], SSL_DICT)
+ assert args == ()
+ assert kwargs['ssl'] == SSL_DICT
@mock.patch('MySQLdb.connect')
def test_get_conn_ssl_as_string(self, mock_connect):
@@ -149,8 +149,8 @@ def test_get_conn_ssl_as_string(self, mock_connect):
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
- self.assertEqual(args, ())
- self.assertEqual(kwargs['ssl'], SSL_DICT)
+ assert args == ()
+ assert kwargs['ssl'] == SSL_DICT
@mock.patch('MySQLdb.connect')
@mock.patch('airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.get_client_type')
@@ -189,11 +189,11 @@ def test_get_conn(self, mock_connect):
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
- self.assertEqual(args, ())
- self.assertEqual(kwargs['user'], 'login')
- self.assertEqual(kwargs['password'], 'password')
- self.assertEqual(kwargs['host'], 'host')
- self.assertEqual(kwargs['database'], 'schema')
+ assert args == ()
+ assert kwargs['user'] == 'login'
+ assert kwargs['password'] == 'password'
+ assert kwargs['host'] == 'host'
+ assert kwargs['database'] == 'schema'
@mock.patch('mysql.connector.connect')
def test_get_conn_port(self, mock_connect):
@@ -201,8 +201,8 @@ def test_get_conn_port(self, mock_connect):
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
- self.assertEqual(args, ())
- self.assertEqual(kwargs['port'], 3307)
+ assert args == ()
+ assert kwargs['port'] == 3307
@mock.patch('mysql.connector.connect')
def test_get_conn_allow_local_infile(self, mock_connect):
@@ -212,8 +212,8 @@ def test_get_conn_allow_local_infile(self, mock_connect):
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
- self.assertEqual(args, ())
- self.assertEqual(kwargs['allow_local_infile'], 1)
+ assert args == ()
+ assert kwargs['allow_local_infile'] == 1
class TestMySqlHook(unittest.TestCase):
@@ -271,9 +271,9 @@ def test_run_multi_queries(self):
self.conn.autocommit.assert_called_once_with(True)
for i in range(len(self.cur.execute.call_args_list)):
args, kwargs = self.cur.execute.call_args_list[i]
- self.assertEqual(len(args), 1)
- self.assertEqual(args[0], sql[i])
- self.assertEqual(kwargs, {})
+ assert len(args) == 1
+ assert args[0] == sql[i]
+ assert kwargs == {}
calls = [mock.call(sql[0]), mock.call(sql[1])]
self.cur.execute.assert_has_calls(calls, any_order=True)
self.conn.commit.assert_not_called()
@@ -297,7 +297,7 @@ def test_bulk_dump(self):
)
def test_serialize_cell(self):
- self.assertEqual('foo', self.db_hook._serialize_cell('foo', None))
+ assert 'foo' == self.db_hook._serialize_cell('foo', None)
def test_bulk_load_custom(self):
self.db_hook.bulk_load_custom(
@@ -387,7 +387,7 @@ def test_mysql_hook_test_bulk_load(self, client):
hook.bulk_load("test_airflow", f.name)
conn.execute("SELECT dummy FROM test_airflow")
results = tuple(result[0] for result in conn.fetchall())
- self.assertEqual(sorted(results), sorted(records))
+ assert sorted(results) == sorted(records)
@parameterized.expand(
[
diff --git a/tests/providers/mysql/transfers/test_s3_to_mysql.py b/tests/providers/mysql/transfers/test_s3_to_mysql.py
index d63bcc5699481..376aaf5043bc3 100644
--- a/tests/providers/mysql/transfers/test_s3_to_mysql.py
+++ b/tests/providers/mysql/transfers/test_s3_to_mysql.py
@@ -18,6 +18,7 @@
import unittest
from unittest.mock import patch
+import pytest
from sqlalchemy import or_
from airflow import configuration, models
@@ -85,7 +86,8 @@ def test_execute(self, mock_remove, mock_bulk_load_custom, mock_download_file):
def test_execute_exception(self, mock_remove, mock_bulk_load_custom, mock_download_file):
mock_bulk_load_custom.side_effect = Exception
- self.assertRaises(Exception, S3ToMySqlOperator(**self.s3_to_mysql_transfer_kwargs).execute, {})
+ with pytest.raises(Exception):
+ S3ToMySqlOperator(**self.s3_to_mysql_transfer_kwargs).execute({})
mock_download_file.assert_called_once_with(key=self.s3_to_mysql_transfer_kwargs['s3_source_key'])
mock_bulk_load_custom.assert_called_once_with(
diff --git a/tests/providers/neo4j/hooks/test_neo4j.py b/tests/providers/neo4j/hooks/test_neo4j.py
index 7f64fc4efbec0..87131c4b5cbc9 100644
--- a/tests/providers/neo4j/hooks/test_neo4j.py
+++ b/tests/providers/neo4j/hooks/test_neo4j.py
@@ -37,7 +37,7 @@ def test_get_uri_neo4j_scheme(self):
self.neo4j_hook.get_connection.return_value = self.connection
uri = self.neo4j_hook.get_uri(self.connection)
- self.assertEqual(uri, "bolt://host:7687")
+ assert uri == "bolt://host:7687"
def test_get_uri_bolt_scheme(self):
@@ -46,7 +46,7 @@ def test_get_uri_bolt_scheme(self):
self.neo4j_hook.get_connection.return_value = self.connection
uri = self.neo4j_hook.get_uri(self.connection)
- self.assertEqual(uri, "bolt://host:7687")
+ assert uri == "bolt://host:7687"
def test_get_uri_bolt_ssc_scheme(self):
self.connection.extra = json.dumps({"certs_self_signed": True, "bolt_scheme": True})
@@ -54,7 +54,7 @@ def test_get_uri_bolt_ssc_scheme(self):
self.neo4j_hook.get_connection.return_value = self.connection
uri = self.neo4j_hook.get_uri(self.connection)
- self.assertEqual(uri, "bolt+ssc://host:7687")
+ assert uri == "bolt+ssc://host:7687"
def test_get_uri_bolt_trusted_ca_scheme(self):
self.connection.extra = json.dumps({"certs_trusted_ca": True, "bolt_scheme": True})
@@ -62,4 +62,4 @@ def test_get_uri_bolt_trusted_ca_scheme(self):
self.neo4j_hook.get_connection.return_value = self.connection
uri = self.neo4j_hook.get_uri(self.connection)
- self.assertEqual(uri, "bolt+s://host:7687")
+ assert uri == "bolt+s://host:7687"
diff --git a/tests/providers/openfaas/hooks/test_openfaas.py b/tests/providers/openfaas/hooks/test_openfaas.py
index b8c40f7fd27ae..baf3704eef093 100644
--- a/tests/providers/openfaas/hooks/test_openfaas.py
+++ b/tests/providers/openfaas/hooks/test_openfaas.py
@@ -20,6 +20,7 @@
import unittest
from unittest import mock
+import pytest
import requests_mock
from airflow.exceptions import AirflowException
@@ -53,7 +54,7 @@ def test_is_function_exist_false(self, mock_get_connection, m):
mock_get_connection.return_value = mock_connection
does_function_exist = self.hook.does_function_exist()
- self.assertFalse(does_function_exist)
+ assert not does_function_exist
@mock.patch.object(BaseHook, 'get_connection')
@requests_mock.mock()
@@ -67,7 +68,7 @@ def test_is_function_exist_true(self, mock_get_connection, m):
mock_get_connection.return_value = mock_connection
does_function_exist = self.hook.does_function_exist()
- self.assertTrue(does_function_exist)
+ assert does_function_exist
@mock.patch.object(BaseHook, 'get_connection')
@requests_mock.mock()
@@ -85,9 +86,9 @@ def test_update_function_false(self, mock_get_connection, m):
mock_connection = Connection(host="http://open-faas.io")
mock_get_connection.return_value = mock_connection
- with self.assertRaises(AirflowException) as context:
+ with pytest.raises(AirflowException) as ctx:
self.hook.update_function({})
- self.assertIn('failed to update ' + FUNCTION_NAME, str(context.exception))
+ assert 'failed to update ' + FUNCTION_NAME in str(ctx.value)
@mock.patch.object(BaseHook, 'get_connection')
@requests_mock.mock()
@@ -100,9 +101,9 @@ def test_invoke_function_false(self, mock_get_connection, m):
mock_connection = Connection(host="http://open-faas.io")
mock_get_connection.return_value = mock_connection
- with self.assertRaises(AirflowException) as context:
+ with pytest.raises(AirflowException) as ctx:
self.hook.invoke_function({})
- self.assertIn('failed to invoke function', str(context.exception))
+ assert 'failed to invoke function' in str(ctx.value)
@mock.patch.object(BaseHook, 'get_connection')
@requests_mock.mock()
@@ -114,7 +115,7 @@ def test_invoke_function_true(self, mock_get_connection, m):
)
mock_connection = Connection(host="http://open-faas.io")
mock_get_connection.return_value = mock_connection
- self.assertEqual(self.hook.invoke_function({}), None)
+ assert self.hook.invoke_function({}) is None
@mock.patch.object(BaseHook, 'get_connection')
@requests_mock.mock()
@@ -127,9 +128,9 @@ def test_invoke_async_function_false(self, mock_get_connection, m):
mock_connection = Connection(host="http://open-faas.io")
mock_get_connection.return_value = mock_connection
- with self.assertRaises(AirflowException) as context:
+ with pytest.raises(AirflowException) as ctx:
self.hook.invoke_async_function({})
- self.assertIn('failed to invoke function', str(context.exception))
+ assert 'failed to invoke function' in str(ctx.value)
@mock.patch.object(BaseHook, 'get_connection')
@requests_mock.mock()
@@ -141,7 +142,7 @@ def test_invoke_async_function_true(self, mock_get_connection, m):
)
mock_connection = Connection(host="http://open-faas.io")
mock_get_connection.return_value = mock_connection
- self.assertEqual(self.hook.invoke_async_function({}), None)
+ assert self.hook.invoke_async_function({}) is None
@mock.patch.object(BaseHook, 'get_connection')
@requests_mock.mock()
@@ -149,7 +150,7 @@ def test_deploy_function_function_already_exist(self, mock_get_connection, m):
m.put("http://open-faas.io/" + self.UPDATE_FUNCTION, json=self.mock_response, status_code=202)
mock_connection = Connection(host="http://open-faas.io/")
mock_get_connection.return_value = mock_connection
- self.assertEqual(self.hook.deploy_function(True, {}), None)
+ assert self.hook.deploy_function(True, {}) is None
@mock.patch.object(BaseHook, 'get_connection')
@requests_mock.mock()
@@ -157,4 +158,4 @@ def test_deploy_function_function_not_exist(self, mock_get_connection, m):
m.post("http://open-faas.io" + self.DEPLOY_FUNCTION, json={}, status_code=202)
mock_connection = Connection(host="http://open-faas.io")
mock_get_connection.return_value = mock_connection
- self.assertEqual(self.hook.deploy_function(False, {}), None)
+ assert self.hook.deploy_function(False, {}) is None
diff --git a/tests/providers/opsgenie/hooks/test_opsgenie_alert.py b/tests/providers/opsgenie/hooks/test_opsgenie_alert.py
index 8a52b4d1e8786..0db9ca4cdbdce 100644
--- a/tests/providers/opsgenie/hooks/test_opsgenie_alert.py
+++ b/tests/providers/opsgenie/hooks/test_opsgenie_alert.py
@@ -19,6 +19,7 @@
import json
import unittest
+import pytest
import requests_mock
from airflow.exceptions import AirflowException
@@ -78,35 +79,33 @@ def setUp(self):
def test_get_api_key(self):
hook = OpsgenieAlertHook(opsgenie_conn_id=self.conn_id)
api_key = hook._get_api_key()
- self.assertEqual('eb243592-faa2-4ba2-a551q-1afdf565c889', api_key)
+ assert 'eb243592-faa2-4ba2-a551q-1afdf565c889' == api_key
def test_get_conn_defaults_host(self):
hook = OpsgenieAlertHook()
hook.get_conn()
- self.assertEqual('https://api.opsgenie.com', hook.base_url)
+ assert 'https://api.opsgenie.com' == hook.base_url
@requests_mock.mock()
def test_call_with_success(self, m):
hook = OpsgenieAlertHook(opsgenie_conn_id=self.conn_id)
m.post(self.opsgenie_alert_endpoint, status_code=202, json=self._mock_success_response_body)
resp = hook.execute(payload=self._payload)
- self.assertEqual(resp.status_code, 202)
- self.assertEqual(resp.json(), self._mock_success_response_body)
+ assert resp.status_code == 202
+ assert resp.json() == self._mock_success_response_body
@requests_mock.mock()
def test_api_key_set(self, m):
hook = OpsgenieAlertHook(opsgenie_conn_id=self.conn_id)
m.post(self.opsgenie_alert_endpoint, status_code=202, json=self._mock_success_response_body)
resp = hook.execute(payload=self._payload)
- self.assertEqual(
- resp.request.headers.get('Authorization'), 'GenieKey eb243592-faa2-4ba2-a551q-1afdf565c889'
- )
+ assert resp.request.headers.get('Authorization') == 'GenieKey eb243592-faa2-4ba2-a551q-1afdf565c889'
@requests_mock.mock()
def test_api_key_not_set(self, m):
hook = OpsgenieAlertHook()
m.post(self.opsgenie_alert_endpoint, status_code=202, json=self._mock_success_response_body)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
hook.execute(payload=self._payload)
@requests_mock.mock()
@@ -114,4 +113,4 @@ def test_payload_set(self, m):
hook = OpsgenieAlertHook(opsgenie_conn_id=self.conn_id)
m.post(self.opsgenie_alert_endpoint, status_code=202, json=self._mock_success_response_body)
resp = hook.execute(payload=self._payload)
- self.assertEqual(json.loads(resp.request.body), self._payload)
+ assert json.loads(resp.request.body) == self._payload
diff --git a/tests/providers/opsgenie/operators/test_opsgenie_alert.py b/tests/providers/opsgenie/operators/test_opsgenie_alert.py
index 40faa63035429..321ba756aa452 100644
--- a/tests/providers/opsgenie/operators/test_opsgenie_alert.py
+++ b/tests/providers/opsgenie/operators/test_opsgenie_alert.py
@@ -84,23 +84,23 @@ def test_build_opsgenie_payload(self):
payload = operator._build_opsgenie_payload()
# Then
- self.assertEqual(self.expected_payload_dict, payload)
+ assert self.expected_payload_dict == payload
def test_properties(self):
# Given / When
operator = OpsgenieAlertOperator(task_id='opsgenie_alert_job', dag=self.dag, **self._config)
- self.assertEqual('opsgenie_default', operator.opsgenie_conn_id)
- self.assertEqual(self._config['message'], operator.message)
- self.assertEqual(self._config['alias'], operator.alias)
- self.assertEqual(self._config['description'], operator.description)
- self.assertEqual(self._config['responders'], operator.responders)
- self.assertEqual(self._config['visible_to'], operator.visible_to)
- self.assertEqual(self._config['actions'], operator.actions)
- self.assertEqual(self._config['tags'], operator.tags)
- self.assertEqual(self._config['details'], operator.details)
- self.assertEqual(self._config['entity'], operator.entity)
- self.assertEqual(self._config['source'], operator.source)
- self.assertEqual(self._config['priority'], operator.priority)
- self.assertEqual(self._config['user'], operator.user)
- self.assertEqual(self._config['note'], operator.note)
+ assert 'opsgenie_default' == operator.opsgenie_conn_id
+ assert self._config['message'] == operator.message
+ assert self._config['alias'] == operator.alias
+ assert self._config['description'] == operator.description
+ assert self._config['responders'] == operator.responders
+ assert self._config['visible_to'] == operator.visible_to
+ assert self._config['actions'] == operator.actions
+ assert self._config['tags'] == operator.tags
+ assert self._config['details'] == operator.details
+ assert self._config['entity'] == operator.entity
+ assert self._config['source'] == operator.source
+ assert self._config['priority'] == operator.priority
+ assert self._config['user'] == operator.user
+ assert self._config['note'] == operator.note
diff --git a/tests/providers/oracle/hooks/test_oracle.py b/tests/providers/oracle/hooks/test_oracle.py
index d27ca44d68704..7e427ddc82335 100644
--- a/tests/providers/oracle/hooks/test_oracle.py
+++ b/tests/providers/oracle/hooks/test_oracle.py
@@ -22,6 +22,7 @@
from unittest import mock
import numpy
+import pytest
from airflow.models import Connection
from airflow.providers.oracle.hooks.oracle import OracleHook
@@ -49,10 +50,10 @@ def test_get_conn_host(self, mock_connect):
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
- self.assertEqual(args, ())
- self.assertEqual(kwargs['user'], 'login')
- self.assertEqual(kwargs['password'], 'password')
- self.assertEqual(kwargs['dsn'], 'host')
+ assert args == ()
+ assert kwargs['user'] == 'login'
+ assert kwargs['password'] == 'password'
+ assert kwargs['dsn'] == 'host'
@mock.patch('airflow.providers.oracle.hooks.oracle.cx_Oracle.connect')
def test_get_conn_sid(self, mock_connect):
@@ -61,10 +62,8 @@ def test_get_conn_sid(self, mock_connect):
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
- self.assertEqual(args, ())
- self.assertEqual(
- kwargs['dsn'], cx_Oracle.makedsn(dsn_sid['dsn'], self.connection.port, dsn_sid['sid'])
- )
+ assert args == ()
+ assert kwargs['dsn'] == cx_Oracle.makedsn(dsn_sid['dsn'], self.connection.port, dsn_sid['sid'])
@mock.patch('airflow.providers.oracle.hooks.oracle.cx_Oracle.connect')
def test_get_conn_service_name(self, mock_connect):
@@ -73,12 +72,9 @@ def test_get_conn_service_name(self, mock_connect):
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
- self.assertEqual(args, ())
- self.assertEqual(
- kwargs['dsn'],
- cx_Oracle.makedsn(
- dsn_service_name['dsn'], self.connection.port, service_name=dsn_service_name['service_name']
- ),
+ assert args == ()
+ assert kwargs['dsn'] == cx_Oracle.makedsn(
+ dsn_service_name['dsn'], self.connection.port, service_name=dsn_service_name['service_name']
)
@mock.patch('airflow.providers.oracle.hooks.oracle.cx_Oracle.connect')
@@ -87,9 +83,9 @@ def test_get_conn_encoding_without_nencoding(self, mock_connect):
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
- self.assertEqual(args, ())
- self.assertEqual(kwargs['encoding'], 'UTF-8')
- self.assertEqual(kwargs['nencoding'], 'UTF-8')
+ assert args == ()
+ assert kwargs['encoding'] == 'UTF-8'
+ assert kwargs['nencoding'] == 'UTF-8'
@mock.patch('airflow.providers.oracle.hooks.oracle.cx_Oracle.connect')
def test_get_conn_encoding_with_nencoding(self, mock_connect):
@@ -97,9 +93,9 @@ def test_get_conn_encoding_with_nencoding(self, mock_connect):
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
- self.assertEqual(args, ())
- self.assertEqual(kwargs['encoding'], 'UTF-8')
- self.assertEqual(kwargs['nencoding'], 'gb2312')
+ assert args == ()
+ assert kwargs['encoding'] == 'UTF-8'
+ assert kwargs['nencoding'] == 'gb2312'
@mock.patch('airflow.providers.oracle.hooks.oracle.cx_Oracle.connect')
def test_get_conn_nencoding(self, mock_connect):
@@ -107,9 +103,9 @@ def test_get_conn_nencoding(self, mock_connect):
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
- self.assertEqual(args, ())
- self.assertNotIn('encoding', kwargs)
- self.assertEqual(kwargs['nencoding'], 'UTF-8')
+ assert args == ()
+ assert 'encoding' not in kwargs
+ assert kwargs['nencoding'] == 'UTF-8'
@mock.patch('airflow.providers.oracle.hooks.oracle.cx_Oracle.connect')
def test_get_conn_mode(self, mock_connect):
@@ -129,8 +125,8 @@ def test_get_conn_mode(self, mock_connect):
assert mock_connect.call_count == 1
first = False
args, kwargs = mock_connect.call_args
- self.assertEqual(args, ())
- self.assertEqual(kwargs['mode'], mode.get(mod))
+ assert args == ()
+ assert kwargs['mode'] == mode.get(mod)
@mock.patch('airflow.providers.oracle.hooks.oracle.cx_Oracle.connect')
def test_get_conn_threaded(self, mock_connect):
@@ -138,8 +134,8 @@ def test_get_conn_threaded(self, mock_connect):
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
- self.assertEqual(args, ())
- self.assertEqual(kwargs['threaded'], True)
+ assert args == ()
+ assert kwargs['threaded'] is True
@mock.patch('airflow.providers.oracle.hooks.oracle.cx_Oracle.connect')
def test_get_conn_events(self, mock_connect):
@@ -147,8 +143,8 @@ def test_get_conn_events(self, mock_connect):
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
- self.assertEqual(args, ())
- self.assertEqual(kwargs['events'], True)
+ assert args == ()
+ assert kwargs['events'] is True
@mock.patch('airflow.providers.oracle.hooks.oracle.cx_Oracle.connect')
def test_get_conn_purity(self, mock_connect):
@@ -165,8 +161,8 @@ def test_get_conn_purity(self, mock_connect):
assert mock_connect.call_count == 1
first = False
args, kwargs = mock_connect.call_args
- self.assertEqual(args, ())
- self.assertEqual(kwargs['purity'], purity.get(pur))
+ assert args == ()
+ assert kwargs['purity'] == purity.get(pur)
@unittest.skipIf(cx_Oracle is None, 'cx_Oracle package not present')
@@ -281,4 +277,5 @@ def test_bulk_insert_rows_without_fields(self):
def test_bulk_insert_rows_no_rows(self):
rows = []
- self.assertRaises(ValueError, self.db_hook.bulk_insert_rows, 'table', rows)
+ with pytest.raises(ValueError):
+ self.db_hook.bulk_insert_rows('table', rows)
diff --git a/tests/providers/pagerduty/hooks/test_pagerduty.py b/tests/providers/pagerduty/hooks/test_pagerduty.py
index 95e9719cd10bb..1c1348d858362 100644
--- a/tests/providers/pagerduty/hooks/test_pagerduty.py
+++ b/tests/providers/pagerduty/hooks/test_pagerduty.py
@@ -51,16 +51,16 @@ def test_without_routing_key_extra(self, session):
)
session.commit()
hook = PagerdutyHook(pagerduty_conn_id="pagerduty_no_extra")
- self.assertEqual(hook.token, 'pagerduty_token_without_extra', 'token initialised.')
- self.assertEqual(hook.routing_key, None, 'default routing key skipped.')
+ assert hook.token == 'pagerduty_token_without_extra', 'token initialised.'
+ assert hook.routing_key is None, 'default routing key skipped.'
def test_get_token_from_password(self):
hook = PagerdutyHook(pagerduty_conn_id=DEFAULT_CONN_ID)
- self.assertEqual(hook.token, 'pagerduty_token', 'token initialised.')
+ assert hook.token == 'pagerduty_token', 'token initialised.'
def test_token_parameter_override(self):
hook = PagerdutyHook(token="pagerduty_param_token", pagerduty_conn_id=DEFAULT_CONN_ID)
- self.assertEqual(hook.token, 'pagerduty_param_token', 'token initialised.')
+ assert hook.token == 'pagerduty_param_token', 'token initialised.'
@requests_mock.mock()
def test_get_service(self, m):
@@ -76,7 +76,7 @@ def test_get_service(self, m):
m.get('https://api.pagerduty.com/services/PZYX321', json={"service": mock_response_body})
session = hook.get_session()
resp = session.rget('/services/PZYX321')
- self.assertEqual(resp, mock_response_body)
+ assert resp == mock_response_body
@requests_mock.mock()
def test_create_event(self, m):
@@ -93,4 +93,4 @@ def test_create_event(self, m):
source="airflow_test",
severity="error",
)
- self.assertEqual(resp, mock_response_body)
+ assert resp == mock_response_body
diff --git a/tests/providers/postgres/hooks/test_postgres.py b/tests/providers/postgres/hooks/test_postgres.py
index 07e100fe2848c..752bd8d75ff4c 100644
--- a/tests/providers/postgres/hooks/test_postgres.py
+++ b/tests/providers/postgres/hooks/test_postgres.py
@@ -73,7 +73,7 @@ def test_get_conn_cursor(self, mock_connect):
@mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect')
def test_get_conn_with_invalid_cursor(self, mock_connect):
self.connection.extra = '{"cursor": "mycursor"}'
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
self.db_hook.get_conn()
@mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect')
@@ -164,13 +164,13 @@ def test_copy_expert(self):
self.cur.fetchall.return_value = None
- self.assertEqual(None, self.db_hook.copy_expert(statement, filename))
+ assert self.db_hook.copy_expert(statement, filename) is None
assert self.conn.close.call_count == 1
assert self.cur.close.call_count == 1
assert self.conn.commit.call_count == 1
self.cur.copy_expert.assert_called_once_with(statement, open_mock.return_value)
- self.assertEqual(open_mock.call_args[0], (filename, "r+"))
+ assert open_mock.call_args[0] == (filename, "r+")
@pytest.mark.backend("postgres")
def test_bulk_load(self):
@@ -190,7 +190,7 @@ def test_bulk_load(self):
cur.execute(f"SELECT * FROM {self.table}")
results = [row[0] for row in cur.fetchall()]
- self.assertEqual(sorted(input_data), sorted(results))
+ assert sorted(input_data) == sorted(results)
@pytest.mark.backend("postgres")
def test_bulk_dump(self):
@@ -209,7 +209,7 @@ def test_bulk_dump(self):
f.seek(0)
results = [line.rstrip().decode("utf-8") for line in f.readlines()]
- self.assertEqual(sorted(input_data), sorted(results))
+ assert sorted(input_data) == sorted(results)
@pytest.mark.backend("postgres")
def test_insert_rows(self):
@@ -222,7 +222,7 @@ def test_insert_rows(self):
assert self.cur.close.call_count == 1
commit_count = 2 # The first and last commit
- self.assertEqual(commit_count, self.conn.commit.call_count)
+ assert commit_count == self.conn.commit.call_count
sql = f"INSERT INTO {table} VALUES (%s)"
for row in rows:
@@ -249,7 +249,7 @@ def test_insert_rows_replace(self):
assert self.cur.close.call_count == 1
commit_count = 2 # The first and last commit
- self.assertEqual(commit_count, self.conn.commit.call_count)
+ assert commit_count == self.conn.commit.call_count
sql = (
"INSERT INTO {0} ({1}, {2}) VALUES (%s,%s) "
@@ -303,4 +303,4 @@ def test_rowcount(self):
values = ",".join(f"('{data}')" for data in input_data)
cur.execute(f"INSERT INTO {self.table} VALUES {values}")
conn.commit()
- self.assertEqual(cur.rowcount, len(input_data))
+ assert cur.rowcount == len(input_data)
diff --git a/tests/providers/presto/hooks/test_presto.py b/tests/providers/presto/hooks/test_presto.py
index 02d2a62a88d72..f9e85875fd8de 100644
--- a/tests/providers/presto/hooks/test_presto.py
+++ b/tests/providers/presto/hooks/test_presto.py
@@ -53,7 +53,7 @@ def test_get_conn_basic_auth(self, mock_get_connection, mock_connect, mock_basic
auth=mock_basic_auth.return_value,
)
mock_basic_auth.assert_called_once_with('login', 'password')
- self.assertEqual(mock_connect.return_value, conn)
+ assert mock_connect.return_value == conn
@patch('airflow.providers.presto.hooks.presto.PrestoHook.get_connection')
def test_get_conn_invalid_auth(self, mock_get_connection):
@@ -64,8 +64,8 @@ def test_get_conn_invalid_auth(self, mock_get_connection):
schema='hive',
extra=json.dumps({'auth': 'kerberos'}),
)
- with self.assertRaisesRegex(
- AirflowException, re.escape("Kerberos authorization doesn't support password.")
+ with pytest.raises(
+ AirflowException, match=re.escape("Kerberos authorization doesn't support password.")
):
PrestoHook().get_conn()
@@ -116,7 +116,7 @@ def test_get_conn_kerberos_auth(self, mock_get_connection, mock_connect, mock_au
sanitize_mutual_error_response=True,
service_name='TEST_SERVICE_NAME',
)
- self.assertEqual(mock_connect.return_value, conn)
+ assert mock_connect.return_value == conn
@parameterized.expand(
[
@@ -140,7 +140,7 @@ def test_get_conn_verify(self, current_verify, expected_verify):
conn = PrestoHook().get_conn()
mock_verify.assert_called_once_with(expected_verify)
- self.assertEqual(mock_connect.return_value, conn)
+ assert mock_connect.return_value == conn
class TestPrestoHook(unittest.TestCase):
@@ -177,7 +177,7 @@ def test_get_first_record(self):
result_sets = [('row1',), ('row2',)]
self.cur.fetchone.return_value = result_sets[0]
- self.assertEqual(result_sets[0], self.db_hook.get_first(statement))
+ assert result_sets[0] == self.db_hook.get_first(statement)
self.conn.close.assert_called_once_with()
self.cur.close.assert_called_once_with()
self.cur.execute.assert_called_once_with(statement)
@@ -187,7 +187,7 @@ def test_get_records(self):
result_sets = [('row1',), ('row2',)]
self.cur.fetchall.return_value = result_sets
- self.assertEqual(result_sets, self.db_hook.get_records(statement))
+ assert result_sets == self.db_hook.get_records(statement)
self.conn.close.assert_called_once_with()
self.cur.close.assert_called_once_with()
self.cur.execute.assert_called_once_with(statement)
@@ -200,10 +200,10 @@ def test_get_pandas_df(self):
self.cur.fetchall.return_value = result_sets
df = self.db_hook.get_pandas_df(statement)
- self.assertEqual(column, df.columns[0])
+ assert column == df.columns[0]
- self.assertEqual(result_sets[0][0], df.values.tolist()[0][0])
- self.assertEqual(result_sets[1][0], df.values.tolist()[1][0])
+ assert result_sets[0][0] == df.values.tolist()[0][0]
+ assert result_sets[1][0] == df.values.tolist()[1][0]
self.cur.execute.assert_called_once_with(statement, None)
@@ -215,7 +215,7 @@ def test_should_record_records(self):
hook = PrestoHook()
sql = "SELECT name FROM tpch.sf1.customer ORDER BY custkey ASC LIMIT 3"
records = hook.get_records(sql)
- self.assertEqual([['Customer#000000001'], ['Customer#000000002'], ['Customer#000000003']], records)
+ assert [['Customer#000000001'], ['Customer#000000002'], ['Customer#000000003']] == records
@pytest.mark.integration("presto")
@pytest.mark.integration("kerberos")
@@ -230,6 +230,4 @@ def test_should_record_records_with_kerberos_auth(self):
hook = PrestoHook()
sql = "SELECT name FROM tpch.sf1.customer ORDER BY custkey ASC LIMIT 3"
records = hook.get_records(sql)
- self.assertEqual(
- [['Customer#000000001'], ['Customer#000000002'], ['Customer#000000003']], records
- )
+ assert [['Customer#000000001'], ['Customer#000000002'], ['Customer#000000003']] == records
diff --git a/tests/providers/qubole/hooks/test_qubole.py b/tests/providers/qubole/hooks/test_qubole.py
index a3718010b58d5..70262cab3c313 100644
--- a/tests/providers/qubole/hooks/test_qubole.py
+++ b/tests/providers/qubole/hooks/test_qubole.py
@@ -27,14 +27,14 @@ class TestQuboleHook(unittest.TestCase):
def test_add_string_to_tags(self):
tags = {'dag_id', 'task_id'}
add_tags(tags, 'string')
- self.assertEqual({'dag_id', 'task_id', 'string'}, tags)
+ assert {'dag_id', 'task_id', 'string'} == tags
def test_add_list_to_tags(self):
tags = {'dag_id', 'task_id'}
add_tags(tags, ['value1', 'value2'])
- self.assertEqual({'dag_id', 'task_id', 'value1', 'value2'}, tags)
+ assert {'dag_id', 'task_id', 'value1', 'value2'} == tags
def test_add_tuple_to_tags(self):
tags = {'dag_id', 'task_id'}
add_tags(tags, ('value1', 'value2'))
- self.assertEqual({'dag_id', 'task_id', 'value1', 'value2'}, tags)
+ assert {'dag_id', 'task_id', 'value1', 'value2'} == tags
diff --git a/tests/providers/qubole/hooks/test_qubole_check.py b/tests/providers/qubole/hooks/test_qubole_check.py
index 4cdb7f5ccb5d8..8bac1e129a1f4 100644
--- a/tests/providers/qubole/hooks/test_qubole_check.py
+++ b/tests/providers/qubole/hooks/test_qubole_check.py
@@ -25,19 +25,19 @@ class TestQuboleCheckHook(unittest.TestCase):
def test_single_row_bool(self):
query_result = ['true\ttrue']
record_list = parse_first_row(query_result)
- self.assertEqual([True, True], record_list)
+ assert [True, True] == record_list
def test_multi_row_bool(self):
query_result = ['true\tfalse', 'true\tfalse']
record_list = parse_first_row(query_result)
- self.assertEqual([True, False], record_list)
+ assert [True, False] == record_list
def test_single_row_float(self):
query_result = ['0.23\t34']
record_list = parse_first_row(query_result)
- self.assertEqual([0.23, 34], record_list)
+ assert [0.23, 34] == record_list
def test_single_row_mixed_types(self):
query_result = ['name\t44\t0.23\tTrue']
record_list = parse_first_row(query_result)
- self.assertEqual(["name", 44, 0.23, True], record_list)
+ assert ["name", 44, 0.23, True] == record_list
diff --git a/tests/providers/qubole/operators/test_qubole.py b/tests/providers/qubole/operators/test_qubole.py
index 480bd84e7a2b4..397666c18b8cd 100644
--- a/tests/providers/qubole/operators/test_qubole.py
+++ b/tests/providers/qubole/operators/test_qubole.py
@@ -49,16 +49,16 @@ def tearDown(self):
def test_init_with_default_connection(self):
op = QuboleOperator(task_id=TASK_ID)
- self.assertEqual(op.task_id, TASK_ID)
- self.assertEqual(op.qubole_conn_id, DEFAULT_CONN)
+ assert op.task_id == TASK_ID
+ assert op.qubole_conn_id == DEFAULT_CONN
def test_init_with_template_connection(self):
with DAG(DAG_ID, start_date=DEFAULT_DATE):
task = QuboleOperator(task_id=TASK_ID, qubole_conn_id="{{ qubole_conn_id }}")
task.render_template_fields({'qubole_conn_id': TEMPLATE_CONN})
- self.assertEqual(task.task_id, TASK_ID)
- self.assertEqual(task.qubole_conn_id, TEMPLATE_CONN)
+ assert task.task_id == TASK_ID
+ assert task.qubole_conn_id == TEMPLATE_CONN
def test_init_with_template_cluster_label(self):
dag = DAG(DAG_ID, start_date=DEFAULT_DATE)
@@ -72,7 +72,7 @@ def test_init_with_template_cluster_label(self):
ti = TaskInstance(task, DEFAULT_DATE)
ti.render_templates()
- self.assertEqual(task.cluster_label, 'default')
+ assert task.cluster_label == 'default'
def test_get_hook(self):
dag = DAG(DAG_ID, start_date=DEFAULT_DATE)
@@ -81,7 +81,7 @@ def test_get_hook(self):
task = QuboleOperator(task_id=TASK_ID, command_type='hivecmd', dag=dag)
hook = task.get_hook()
- self.assertEqual(hook.__class__, QuboleHook)
+ assert hook.__class__ == QuboleHook
def test_hyphen_args_note_id(self):
dag = DAG(DAG_ID, start_date=DEFAULT_DATE)
@@ -89,7 +89,7 @@ def test_hyphen_args_note_id(self):
with dag:
task = QuboleOperator(task_id=TASK_ID, command_type='sparkcmd', note_id="123", dag=dag)
- self.assertEqual(task.get_hook().create_cmd_args({'run_id': 'dummy'})[0], "--note-id=123")
+ assert task.get_hook().create_cmd_args({'run_id': 'dummy'})[0] == "--note-id=123"
def test_notify(self):
dag = DAG(DAG_ID, start_date=DEFAULT_DATE)
@@ -97,7 +97,7 @@ def test_notify(self):
with dag:
task = QuboleOperator(task_id=TASK_ID, command_type='sparkcmd', notify=True, dag=dag)
- self.assertEqual(task.get_hook().create_cmd_args({'run_id': 'dummy'})[0], "--notify")
+ assert task.get_hook().create_cmd_args({'run_id': 'dummy'})[0] == "--notify"
def test_position_args_parameters(self):
dag = DAG(DAG_ID, start_date=DEFAULT_DATE)
@@ -107,20 +107,18 @@ def test_position_args_parameters(self):
task_id=TASK_ID, command_type='pigcmd', parameters="key1=value1 key2=value2", dag=dag
)
- self.assertEqual(task.get_hook().create_cmd_args({'run_id': 'dummy'})[1], "key1=value1")
- self.assertEqual(task.get_hook().create_cmd_args({'run_id': 'dummy'})[2], "key2=value2")
+ assert task.get_hook().create_cmd_args({'run_id': 'dummy'})[1] == "key1=value1"
+ assert task.get_hook().create_cmd_args({'run_id': 'dummy'})[2] == "key2=value2"
cmd = "s3distcp --src s3n://airflow/source_hadoopcmd --dest s3n://airflow/destination_hadoopcmd"
task = QuboleOperator(task_id=TASK_ID + "_1", command_type='hadoopcmd', dag=dag, sub_command=cmd)
- self.assertEqual(task.get_hook().create_cmd_args({'run_id': 'dummy'})[1], "s3distcp")
- self.assertEqual(task.get_hook().create_cmd_args({'run_id': 'dummy'})[2], "--src")
- self.assertEqual(
- task.get_hook().create_cmd_args({'run_id': 'dummy'})[3], "s3n://airflow/source_hadoopcmd"
- )
- self.assertEqual(task.get_hook().create_cmd_args({'run_id': 'dummy'})[4], "--dest")
- self.assertEqual(
- task.get_hook().create_cmd_args({'run_id': 'dummy'})[5], "s3n://airflow/destination_hadoopcmd"
+ assert task.get_hook().create_cmd_args({'run_id': 'dummy'})[1] == "s3distcp"
+ assert task.get_hook().create_cmd_args({'run_id': 'dummy'})[2] == "--src"
+ assert task.get_hook().create_cmd_args({'run_id': 'dummy'})[3] == "s3n://airflow/source_hadoopcmd"
+ assert task.get_hook().create_cmd_args({'run_id': 'dummy'})[4] == "--dest"
+ assert (
+ task.get_hook().create_cmd_args({'run_id': 'dummy'})[5] == "s3n://airflow/destination_hadoopcmd"
)
def test_get_redirect_url(self):
@@ -140,11 +138,11 @@ def test_get_redirect_url(self):
# check for positive case
url = task.get_extra_links(DEFAULT_DATE, 'Go to QDS')
- self.assertEqual(url, 'http://localhost/v2/analyze?command_id=12345')
+ assert url == 'http://localhost/v2/analyze?command_id=12345'
# check for negative case
url2 = task.get_extra_links(datetime(2017, 1, 2), 'Go to QDS')
- self.assertEqual(url2, '')
+ assert url2 == ''
def test_extra_serialized_field(self):
dag = DAG(DAG_ID, start_date=DEFAULT_DATE)
@@ -156,29 +154,29 @@ def test_extra_serialized_field(self):
)
serialized_dag = SerializedDAG.to_dict(dag)
- self.assertIn("qubole_conn_id", serialized_dag["dag"]["tasks"][0])
+ assert "qubole_conn_id" in serialized_dag["dag"]["tasks"][0]
dag = SerializedDAG.from_dict(serialized_dag)
simple_task = dag.task_dict[TASK_ID]
- self.assertEqual(getattr(simple_task, "qubole_conn_id"), TEST_CONN)
+ assert getattr(simple_task, "qubole_conn_id") == TEST_CONN
#########################################################
# Verify Operator Links work with Serialized Operator
#########################################################
- self.assertIsInstance(list(simple_task.operator_extra_links)[0], QDSLink)
+ assert isinstance(list(simple_task.operator_extra_links)[0], QDSLink)
ti = TaskInstance(task=simple_task, execution_date=DEFAULT_DATE)
ti.xcom_push('qbol_cmd_id', 12345)
# check for positive case
url = simple_task.get_extra_links(DEFAULT_DATE, 'Go to QDS')
- self.assertEqual(url, 'http://localhost/v2/analyze?command_id=12345')
+ assert url == 'http://localhost/v2/analyze?command_id=12345'
# check for negative case
url2 = simple_task.get_extra_links(datetime(2017, 1, 2), 'Go to QDS')
- self.assertEqual(url2, '')
+ assert url2 == ''
def test_parameter_pool_passed(self):
test_pool = 'test_pool'
op = QuboleOperator(task_id=TASK_ID, pool=test_pool)
- self.assertEqual(op.pool, test_pool)
+ assert op.pool == test_pool
diff --git a/tests/providers/qubole/operators/test_qubole_check.py b/tests/providers/qubole/operators/test_qubole_check.py
index 227f0490a5d35..48383f9bf4ffa 100644
--- a/tests/providers/qubole/operators/test_qubole_check.py
+++ b/tests/providers/qubole/operators/test_qubole_check.py
@@ -20,6 +20,7 @@
from datetime import datetime
from unittest import mock
+import pytest
from qds_sdk.commands import HiveCommand
from airflow.exceptions import AirflowException
@@ -54,8 +55,8 @@ def test_pass_value_template(self):
operator = self.__construct_operator('select date from tab1;', "{{ ds }}")
result = operator.render_template(operator.pass_value, {'ds': pass_value_str})
- self.assertEqual(operator.task_id, self.task_id)
- self.assertEqual(result, pass_value_str)
+ assert operator.task_id == self.task_id
+ assert result == pass_value_str
@mock.patch.object(QuboleValueCheckOperator, 'get_hook')
def test_execute_pass(self, mock_get_hook):
@@ -87,7 +88,7 @@ def test_execute_assertion_fail(self, mock_get_hook):
operator = self.__construct_operator('select value from tab1 limit 1;', 5, 1)
- with self.assertRaisesRegex(AirflowException, 'Qubole Command Id: ' + str(mock_cmd.id)):
+ with pytest.raises(AirflowException, match='Qubole Command Id: ' + str(mock_cmd.id)):
operator.execute()
mock_cmd.is_success.assert_called_once_with(mock_cmd.status)
@@ -107,10 +108,10 @@ def test_execute_assert_query_fail(self, mock_get_hook):
operator = self.__construct_operator('select value from tab1 limit 1;', 5, 1)
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
operator.execute()
- self.assertNotIn('Qubole Command Id: ', str(cm.exception))
+ assert 'Qubole Command Id: ' not in str(ctx.value)
mock_cmd.is_success.assert_called_once_with(mock_cmd.status)
@mock.patch.object(QuboleCheckHook, 'get_query_results')
diff --git a/tests/providers/qubole/sensors/test_qubole.py b/tests/providers/qubole/sensors/test_qubole.py
index e7bed5364798d..470af51a4dab2 100644
--- a/tests/providers/qubole/sensors/test_qubole.py
+++ b/tests/providers/qubole/sensors/test_qubole.py
@@ -21,6 +21,8 @@
from datetime import datetime
from unittest.mock import patch
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.models import DAG, Connection
from airflow.providers.qubole.sensors.qubole import QuboleFileSensor, QubolePartitionSensor
@@ -43,7 +45,7 @@ def test_file_sensore(self, patched_poke):
sensor = QuboleFileSensor(
task_id='test_qubole_file_sensor', data={"files": ["s3://some_bucket/some_file"]}
)
- self.assertTrue(sensor.poke({}))
+ assert sensor.poke({})
@patch('airflow.providers.qubole.sensors.qubole.QubolePartitionSensor.poke')
def test_partition_sensor(self, patched_poke):
@@ -58,7 +60,7 @@ def test_partition_sensor(self, patched_poke):
},
)
- self.assertTrue(sensor.poke({}))
+ assert sensor.poke({})
@patch('airflow.providers.qubole.sensors.qubole.QubolePartitionSensor.poke')
def test_partition_sensor_error(self, patched_poke):
@@ -66,7 +68,7 @@ def test_partition_sensor_error(self, patched_poke):
dag = DAG(DAG_ID, start_date=DEFAULT_DATE)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
QubolePartitionSensor(
task_id='test_qubole_partition_sensor',
poke_interval=1,
diff --git a/tests/providers/redis/hooks/test_redis.py b/tests/providers/redis/hooks/test_redis.py
index 4328571c86dcb..e4681b3aee771 100644
--- a/tests/providers/redis/hooks/test_redis.py
+++ b/tests/providers/redis/hooks/test_redis.py
@@ -29,13 +29,13 @@
class TestRedisHook(unittest.TestCase):
def test_get_conn(self):
hook = RedisHook(redis_conn_id='redis_default')
- self.assertEqual(hook.redis, None)
+ assert hook.redis is None
- self.assertEqual(hook.host, None, 'host initialised as None.')
- self.assertEqual(hook.port, None, 'port initialised as None.')
- self.assertEqual(hook.password, None, 'password initialised as None.')
- self.assertEqual(hook.db, None, 'db initialised as None.')
- self.assertIs(hook.get_conn(), hook.get_conn(), 'Connection initialized only if None.')
+ assert hook.host is None, 'host initialised as None.'
+ assert hook.port is None, 'port initialised as None.'
+ assert hook.password is None, 'password initialised as None.'
+ assert hook.db is None, 'db initialised as None.'
+ assert hook.get_conn() is hook.get_conn(), 'Connection initialized only if None.'
@mock.patch('airflow.providers.redis.hooks.redis.Redis')
@mock.patch(
@@ -76,20 +76,20 @@ def test_get_conn_with_extra_config(self, mock_get_connection, mock_redis):
def test_get_conn_password_stays_none(self):
hook = RedisHook(redis_conn_id='redis_default')
hook.get_conn()
- self.assertEqual(hook.password, None)
+ assert hook.password is None
@pytest.mark.integration("redis")
def test_real_ping(self):
hook = RedisHook(redis_conn_id='redis_default')
redis = hook.get_conn()
- self.assertTrue(redis.ping(), 'Connection to Redis with PING works.')
+ assert redis.ping(), 'Connection to Redis with PING works.'
@pytest.mark.integration("redis")
def test_real_get_and_set(self):
hook = RedisHook(redis_conn_id='redis_default')
redis = hook.get_conn()
- self.assertTrue(redis.set('test_key', 'test_value'), 'Connection to Redis with SET works.')
- self.assertEqual(redis.get('test_key'), b'test_value', 'Connection to Redis with GET works.')
- self.assertEqual(redis.delete('test_key'), 1, 'Connection to Redis with DELETE works.')
+ assert redis.set('test_key', 'test_value'), 'Connection to Redis with SET works.'
+ assert redis.get('test_key') == b'test_value', 'Connection to Redis with GET works.'
+ assert redis.delete('test_key') == 1, 'Connection to Redis with DELETE works.'
diff --git a/tests/providers/redis/operators/test_redis_publish.py b/tests/providers/redis/operators/test_redis_publish.py
index 9f61214a5df28..152a14fadcd37 100644
--- a/tests/providers/redis/operators/test_redis_publish.py
+++ b/tests/providers/redis/operators/test_redis_publish.py
@@ -55,13 +55,13 @@ def test_execute_hello(self):
operator.execute(self.mock_context)
context_calls = []
- self.assertTrue(self.mock_context['ti'].method_calls == context_calls, "context calls should be same")
+ assert self.mock_context['ti'].method_calls == context_calls, "context calls should be same"
message = pubsub.get_message()
- self.assertEqual(message['type'], 'subscribe')
+ assert message['type'] == 'subscribe'
message = pubsub.get_message()
- self.assertEqual(message['type'], 'message')
- self.assertEqual(message['data'], b'hello')
+ assert message['type'] == 'message'
+ assert message['data'] == b'hello'
pubsub.unsubscribe(self.channel)
diff --git a/tests/providers/redis/sensors/test_redis_key.py b/tests/providers/redis/sensors/test_redis_key.py
index a67582a0ab1f1..c22aa72795994 100644
--- a/tests/providers/redis/sensors/test_redis_key.py
+++ b/tests/providers/redis/sensors/test_redis_key.py
@@ -43,6 +43,6 @@ def test_poke(self):
hook = RedisHook(redis_conn_id='redis_default')
redis = hook.get_conn()
redis.set('test_key', 'test_value')
- self.assertTrue(self.sensor.poke(None), "Key exists on first call.")
+ assert self.sensor.poke(None), "Key exists on first call."
redis.delete('test_key')
- self.assertFalse(self.sensor.poke(None), "Key does NOT exists on second call.")
+ assert not self.sensor.poke(None), "Key does NOT exists on second call."
diff --git a/tests/providers/redis/sensors/test_redis_pub_sub.py b/tests/providers/redis/sensors/test_redis_pub_sub.py
index 207a8268c6184..4a878270fbd5d 100644
--- a/tests/providers/redis/sensors/test_redis_pub_sub.py
+++ b/tests/providers/redis/sensors/test_redis_pub_sub.py
@@ -51,13 +51,13 @@ def test_poke_mock_true(self, mock_redis_conn):
}
result = sensor.poke(self.mock_context)
- self.assertTrue(result)
+ assert result
context_calls = [
call.xcom_push(key='message', value={'type': 'message', 'channel': b'test', 'data': b'd1'})
]
- self.assertTrue(self.mock_context['ti'].method_calls == context_calls, "context call should be same")
+ assert self.mock_context['ti'].method_calls == context_calls, "context call should be same"
@patch('airflow.providers.redis.hooks.redis.RedisHook.get_conn')
def test_poke_mock_false(self, mock_redis_conn):
@@ -72,10 +72,10 @@ def test_poke_mock_false(self, mock_redis_conn):
}
result = sensor.poke(self.mock_context)
- self.assertFalse(result)
+ assert not result
context_calls = []
- self.assertTrue(self.mock_context['ti'].method_calls == context_calls, "context calls should be same")
+ assert self.mock_context['ti'].method_calls == context_calls, "context calls should be same"
@pytest.mark.integration("redis")
def test_poke_true(self):
@@ -88,18 +88,18 @@ def test_poke_true(self):
redis.publish('test', 'message')
result = sensor.poke(self.mock_context)
- self.assertFalse(result)
+ assert not result
result = sensor.poke(self.mock_context)
- self.assertTrue(result)
+ assert result
context_calls = [
call.xcom_push(
key='message',
value={'type': 'message', 'pattern': None, 'channel': b'test', 'data': b'message'},
)
]
- self.assertTrue(self.mock_context['ti'].method_calls == context_calls, "context calls should be same")
+ assert self.mock_context['ti'].method_calls == context_calls, "context calls should be same"
result = sensor.poke(self.mock_context)
- self.assertFalse(result)
+ assert not result
@pytest.mark.integration("redis")
def test_poke_false(self):
@@ -108,8 +108,8 @@ def test_poke_false(self):
)
result = sensor.poke(self.mock_context)
- self.assertFalse(result)
- self.assertTrue(self.mock_context['ti'].method_calls == [], "context calls should be same")
+ assert not result
+ assert self.mock_context['ti'].method_calls == [], "context calls should be same"
result = sensor.poke(self.mock_context)
- self.assertFalse(result)
- self.assertTrue(self.mock_context['ti'].method_calls == [], "context calls should be same")
+ assert not result
+ assert self.mock_context['ti'].method_calls == [], "context calls should be same"
diff --git a/tests/providers/salesforce/hooks/test_salesforce.py b/tests/providers/salesforce/hooks/test_salesforce.py
index fb96227529766..c821057917b13 100644
--- a/tests/providers/salesforce/hooks/test_salesforce.py
+++ b/tests/providers/salesforce/hooks/test_salesforce.py
@@ -21,6 +21,7 @@
from unittest.mock import Mock, patch
import pandas as pd
+import pytest
from numpy import nan
from simple_salesforce import Salesforce
@@ -37,7 +38,7 @@ def test_get_conn_exists(self):
self.salesforce_hook.get_conn()
- self.assertIsNotNone(self.salesforce_hook.conn.return_value)
+ assert self.salesforce_hook.conn.return_value is not None
@patch(
"airflow.providers.salesforce.hooks.salesforce.SalesforceHook.get_connection",
@@ -49,7 +50,7 @@ def test_get_conn_exists(self):
def test_get_conn(self, mock_salesforce, mock_get_connection):
self.salesforce_hook.get_conn()
- self.assertEqual(self.salesforce_hook.conn, mock_salesforce.return_value)
+ assert self.salesforce_hook.conn == mock_salesforce.return_value
mock_salesforce.assert_called_once_with(
username=mock_get_connection.return_value.login,
password=mock_get_connection.return_value.password,
@@ -67,7 +68,7 @@ def test_make_query(self, mock_salesforce):
query_results = self.salesforce_hook.make_query(query, include_deleted=True)
mock_salesforce.return_value.query_all.assert_called_once_with(query, include_deleted=True)
- self.assertEqual(query_results, mock_salesforce.return_value.query_all.return_value)
+ assert query_results == mock_salesforce.return_value.query_all.return_value
@patch("airflow.providers.salesforce.hooks.salesforce.Salesforce")
def test_describe_object(self, mock_salesforce):
@@ -78,7 +79,7 @@ def test_describe_object(self, mock_salesforce):
obj_description = self.salesforce_hook.describe_object(obj)
mock_salesforce.return_value.__getattr__(obj).describe.assert_called_once_with()
- self.assertEqual(obj_description, mock_salesforce.return_value.__getattr__(obj).describe.return_value)
+ assert obj_description == mock_salesforce.return_value.__getattr__(obj).describe.return_value
@patch("airflow.providers.salesforce.hooks.salesforce.SalesforceHook.get_conn")
@patch(
@@ -92,7 +93,7 @@ def test_get_available_fields(self, mock_describe_object, mock_get_conn):
mock_get_conn.assert_called_once_with()
mock_describe_object.assert_called_once_with(obj)
- self.assertEqual(available_fields, ["field_1", "field_2"])
+ assert available_fields == ["field_1", "field_2"]
@patch("airflow.providers.salesforce.hooks.salesforce.SalesforceHook.make_query")
def test_get_object_from_salesforce(self, mock_make_query):
@@ -101,10 +102,10 @@ def test_get_object_from_salesforce(self, mock_make_query):
)
mock_make_query.assert_called_once_with("SELECT field_1,field_2 FROM obj_name")
- self.assertEqual(salesforce_objects, mock_make_query.return_value)
+ assert salesforce_objects == mock_make_query.return_value
def test_write_object_to_file_invalid_format(self):
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
self.salesforce_hook.write_object_to_file(query_results=[], filename="test", fmt="test")
@patch(
diff --git a/tests/providers/salesforce/hooks/test_tableau.py b/tests/providers/salesforce/hooks/test_tableau.py
index b416965c1c1fb..130746d43b268 100644
--- a/tests/providers/salesforce/hooks/test_tableau.py
+++ b/tests/providers/salesforce/hooks/test_tableau.py
@@ -80,6 +80,6 @@ def test_get_conn_auth_via_token_and_site_in_init(self, mock_server, mock_tablea
def test_get_all(self, mock_pager, mock_server, mock_tableau_auth): # pylint: disable=unused-argument
with TableauHook(tableau_conn_id='tableau_test_password') as tableau_hook:
jobs = tableau_hook.get_all(resource_name='jobs')
- self.assertEqual(jobs, mock_pager.return_value)
+ assert jobs == mock_pager.return_value
mock_pager.assert_called_once_with(mock_server.return_value.jobs.get)
diff --git a/tests/providers/salesforce/operators/test_tableau_refresh_workbook.py b/tests/providers/salesforce/operators/test_tableau_refresh_workbook.py
index 4751cc9eb10df..77139c19773bc 100644
--- a/tests/providers/salesforce/operators/test_tableau_refresh_workbook.py
+++ b/tests/providers/salesforce/operators/test_tableau_refresh_workbook.py
@@ -18,6 +18,8 @@
import unittest
from unittest.mock import Mock, patch
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.salesforce.operators.tableau_refresh_workbook import TableauRefreshWorkbookOperator
@@ -41,7 +43,7 @@ def test_execute(self, mock_tableau_hook):
job_id = operator.execute(context={})
mock_tableau_hook.server.workbooks.refresh.assert_called_once_with(2)
- self.assertEqual(mock_tableau_hook.server.workbooks.refresh.return_value.id, job_id)
+ assert mock_tableau_hook.server.workbooks.refresh.return_value.id == job_id
@patch('airflow.providers.salesforce.sensors.tableau_job_status.TableauJobStatusSensor')
@patch('airflow.providers.salesforce.operators.tableau_refresh_workbook.TableauHook')
@@ -53,7 +55,7 @@ def test_execute_blocking(self, mock_tableau_hook, mock_tableau_job_status_senso
job_id = operator.execute(context={})
mock_tableau_hook.server.workbooks.refresh.assert_called_once_with(2)
- self.assertEqual(mock_tableau_hook.server.workbooks.refresh.return_value.id, job_id)
+ assert mock_tableau_hook.server.workbooks.refresh.return_value.id == job_id
mock_tableau_job_status_sensor.assert_called_once_with(
job_id=job_id,
site_id=self.kwargs['site_id'],
@@ -68,4 +70,5 @@ def test_execute_missing_workbook(self, mock_tableau_hook):
mock_tableau_hook.return_value.__enter__ = Mock(return_value=mock_tableau_hook)
operator = TableauRefreshWorkbookOperator(workbook_name='test', **self.kwargs)
- self.assertRaises(AirflowException, operator.execute, {})
+ with pytest.raises(AirflowException):
+ operator.execute({})
diff --git a/tests/providers/salesforce/sensors/test_tableau_job_status.py b/tests/providers/salesforce/sensors/test_tableau_job_status.py
index f8b7c3e4d82fd..7f01011befc91 100644
--- a/tests/providers/salesforce/sensors/test_tableau_job_status.py
+++ b/tests/providers/salesforce/sensors/test_tableau_job_status.py
@@ -18,6 +18,7 @@
import unittest
from unittest.mock import Mock, patch
+import pytest
from parameterized import parameterized
from airflow.providers.salesforce.sensors.tableau_job_status import (
@@ -39,7 +40,7 @@ def test_poke(self, mock_tableau_hook):
job_finished = sensor.poke(context={})
- self.assertTrue(job_finished)
+ assert job_finished
mock_get.assert_called_once_with(sensor.job_id)
@parameterized.expand([('1',), ('2',)])
@@ -50,5 +51,6 @@ def test_poke_failed(self, finish_code, mock_tableau_hook):
mock_get.return_value.finish_code = finish_code
sensor = TableauJobStatusSensor(**self.kwargs)
- self.assertRaises(TableauJobFailedException, sensor.poke, {})
+ with pytest.raises(TableauJobFailedException):
+ sensor.poke({})
mock_get.assert_called_once_with(sensor.job_id)
diff --git a/tests/providers/samba/hooks/test_samba.py b/tests/providers/samba/hooks/test_samba.py
index 457a7530b6d53..61b90c7b0b414 100644
--- a/tests/providers/samba/hooks/test_samba.py
+++ b/tests/providers/samba/hooks/test_samba.py
@@ -20,6 +20,7 @@
from unittest import mock
from unittest.mock import call
+import pytest
import smbclient
from airflow.exceptions import AirflowException
@@ -31,7 +32,7 @@
class TestSambaHook(unittest.TestCase):
def test_get_conn_should_fail_if_conn_id_does_not_exist(self):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
SambaHook('conn')
@mock.patch('airflow.hooks.base.BaseHook.get_connection')
@@ -39,7 +40,7 @@ def test_get_conn(self, get_conn_mock):
get_conn_mock.return_value = connection
hook = SambaHook('samba_default')
- self.assertEqual(smbclient.SambaClient, type(hook.get_conn()))
+ assert isinstance(hook.get_conn(), smbclient.SambaClient)
get_conn_mock.assert_called_once_with('samba_default')
@mock.patch('airflow.providers.samba.hooks.samba.SambaHook.get_conn')
diff --git a/tests/providers/segment/hooks/test_segment.py b/tests/providers/segment/hooks/test_segment.py
index 723a646391ada..66e3d6052018b 100644
--- a/tests/providers/segment/hooks/test_segment.py
+++ b/tests/providers/segment/hooks/test_segment.py
@@ -19,6 +19,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.segment.hooks.segment import SegmentHook
@@ -46,10 +48,10 @@ def get_connection(self, _):
def test_get_conn(self):
expected_connection = self.test_hook.get_conn()
- self.assertEqual(expected_connection, self.conn)
- self.assertIsNotNone(expected_connection.write_key)
- self.assertEqual(expected_connection.write_key, self.expected_write_key)
+ assert expected_connection == self.conn
+ assert expected_connection.write_key is not None
+ assert expected_connection.write_key == self.expected_write_key
def test_on_error(self):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.test_hook.on_error('error', ['items'])
diff --git a/tests/providers/segment/operators/test_segment_track_event.py b/tests/providers/segment/operators/test_segment_track_event.py
index 1272eca9524fc..948e6610e2837 100644
--- a/tests/providers/segment/operators/test_segment_track_event.py
+++ b/tests/providers/segment/operators/test_segment_track_event.py
@@ -19,6 +19,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.segment.hooks.segment import SegmentHook
from airflow.providers.segment.operators.segment_track_event import SegmentTrackEventOperator
@@ -47,12 +49,12 @@ def get_connection(self, unused_connection_id):
def test_get_conn(self):
expected_connection = self.test_hook.get_conn()
- self.assertEqual(expected_connection, self.conn)
- self.assertIsNotNone(expected_connection.write_key)
- self.assertEqual(expected_connection.write_key, self.expected_write_key)
+ assert expected_connection == self.conn
+ assert expected_connection.write_key is not None
+ assert expected_connection.write_key == self.expected_write_key
def test_on_error(self):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self.test_hook.on_error('error', ['items'])
diff --git a/tests/providers/sftp/hooks/test_sftp.py b/tests/providers/sftp/hooks/test_sftp.py
index 9211c30abc254..8af8a0bc8146f 100644
--- a/tests/providers/sftp/hooks/test_sftp.py
+++ b/tests/providers/sftp/hooks/test_sftp.py
@@ -71,30 +71,30 @@ def setUp(self):
def test_get_conn(self):
output = self.hook.get_conn()
- self.assertEqual(type(output), pysftp.Connection)
+ assert isinstance(output, pysftp.Connection)
def test_close_conn(self):
self.hook.conn = self.hook.get_conn()
- self.assertTrue(self.hook.conn is not None)
+ assert self.hook.conn is not None
self.hook.close_conn()
- self.assertTrue(self.hook.conn is None)
+ assert self.hook.conn is None
def test_describe_directory(self):
output = self.hook.describe_directory(TMP_PATH)
- self.assertTrue(TMP_DIR_FOR_TESTS in output)
+ assert TMP_DIR_FOR_TESTS in output
def test_list_directory(self):
output = self.hook.list_directory(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
- self.assertEqual(output, [SUB_DIR])
+ assert output == [SUB_DIR]
def test_create_and_delete_directory(self):
new_dir_name = 'new_dir'
self.hook.create_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name))
output = self.hook.describe_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
- self.assertTrue(new_dir_name in output)
+ assert new_dir_name in output
self.hook.delete_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name))
output = self.hook.describe_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
- self.assertTrue(new_dir_name not in output)
+ assert new_dir_name not in output
def test_create_and_delete_directories(self):
base_dir = "base_dir"
@@ -102,14 +102,14 @@ def test_create_and_delete_directories(self):
new_dir_path = os.path.join(base_dir, sub_dir)
self.hook.create_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path))
output = self.hook.describe_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
- self.assertTrue(base_dir in output)
+ assert base_dir in output
output = self.hook.describe_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, base_dir))
- self.assertTrue(sub_dir in output)
+ assert sub_dir in output
self.hook.delete_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path))
self.hook.delete_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, base_dir))
output = self.hook.describe_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
- self.assertTrue(new_dir_path not in output)
- self.assertTrue(base_dir not in output)
+ assert new_dir_path not in output
+ assert base_dir not in output
def test_store_retrieve_and_delete_file(self):
self.hook.store_file(
@@ -117,17 +117,17 @@ def test_store_retrieve_and_delete_file(self):
local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS),
)
output = self.hook.list_directory(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
- self.assertEqual(output, [SUB_DIR, TMP_FILE_FOR_TESTS])
+ assert output == [SUB_DIR, TMP_FILE_FOR_TESTS]
retrieved_file_name = 'retrieved.txt'
self.hook.retrieve_file(
remote_full_path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
local_full_path=os.path.join(TMP_PATH, retrieved_file_name),
)
- self.assertTrue(retrieved_file_name in os.listdir(TMP_PATH))
+ assert retrieved_file_name in os.listdir(TMP_PATH)
os.remove(os.path.join(TMP_PATH, retrieved_file_name))
self.hook.delete_file(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS))
output = self.hook.list_directory(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
- self.assertEqual(output, [SUB_DIR])
+ assert output == [SUB_DIR]
def test_get_mod_time(self):
self.hook.store_file(
@@ -135,14 +135,14 @@ def test_get_mod_time(self):
local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS),
)
output = self.hook.get_mod_time(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS))
- self.assertEqual(len(output), 14)
+ assert len(output) == 14
@mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
def test_no_host_key_check_default(self, get_connection):
connection = Connection(login='login', host='host')
get_connection.return_value = connection
hook = SFTPHook()
- self.assertEqual(hook.no_host_key_check, False)
+ assert hook.no_host_key_check is False
@mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
def test_no_host_key_check_enabled(self, get_connection):
@@ -150,7 +150,7 @@ def test_no_host_key_check_enabled(self, get_connection):
get_connection.return_value = connection
hook = SFTPHook()
- self.assertEqual(hook.no_host_key_check, True)
+ assert hook.no_host_key_check is True
@mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
def test_no_host_key_check_disabled(self, get_connection):
@@ -158,7 +158,7 @@ def test_no_host_key_check_disabled(self, get_connection):
get_connection.return_value = connection
hook = SFTPHook()
- self.assertEqual(hook.no_host_key_check, False)
+ assert hook.no_host_key_check is False
@mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
def test_ciphers(self, get_connection):
@@ -166,7 +166,7 @@ def test_ciphers(self, get_connection):
get_connection.return_value = connection
hook = SFTPHook()
- self.assertEqual(hook.ciphers, ["A", "B", "C"])
+ assert hook.ciphers == ["A", "B", "C"]
@mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
def test_no_host_key_check_disabled_for_all_but_true(self, get_connection):
@@ -174,7 +174,7 @@ def test_no_host_key_check_disabled_for_all_but_true(self, get_connection):
get_connection.return_value = connection
hook = SFTPHook()
- self.assertEqual(hook.no_host_key_check, False)
+ assert hook.no_host_key_check is False
@mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
def test_no_host_key_check_ignore(self, get_connection):
@@ -182,7 +182,7 @@ def test_no_host_key_check_ignore(self, get_connection):
get_connection.return_value = connection
hook = SFTPHook()
- self.assertEqual(hook.no_host_key_check, True)
+ assert hook.no_host_key_check is True
@mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
def test_no_host_key_check_no_ignore(self, get_connection):
@@ -190,14 +190,14 @@ def test_no_host_key_check_no_ignore(self, get_connection):
get_connection.return_value = connection
hook = SFTPHook()
- self.assertEqual(hook.no_host_key_check, False)
+ assert hook.no_host_key_check is False
@mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
def test_host_key_default(self, get_connection):
connection = Connection(login='login', host='host')
get_connection.return_value = connection
hook = SFTPHook()
- self.assertEqual(hook.host_key, None)
+ assert hook.host_key is None
@mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
def test_host_key(self, get_connection):
@@ -208,14 +208,14 @@ def test_host_key(self, get_connection):
)
get_connection.return_value = connection
hook = SFTPHook()
- self.assertEqual(hook.host_key.get_base64(), TEST_HOST_KEY)
+ assert hook.host_key.get_base64() == TEST_HOST_KEY
@mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection')
def test_host_key_with_no_host_key_check(self, get_connection):
connection = Connection(login='login', host='host', extra=json.dumps({"host_key": TEST_HOST_KEY}))
get_connection.return_value = connection
hook = SFTPHook()
- self.assertEqual(hook.host_key, None)
+ assert hook.host_key is None
@parameterized.expand(
[
@@ -227,7 +227,7 @@ def test_host_key_with_no_host_key_check(self, get_connection):
)
def test_path_exists(self, path, exists):
result = self.hook.path_exists(path)
- self.assertEqual(result, exists)
+ assert result == exists
@parameterized.expand(
[
@@ -246,15 +246,15 @@ def test_path_exists(self, path, exists):
)
def test_path_match(self, path, prefix, delimiter, match):
result = self.hook._is_path_match(path=path, prefix=prefix, delimiter=delimiter)
- self.assertEqual(result, match)
+ assert result == match
def test_get_tree_map(self):
tree_map = self.hook.get_tree_map(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
files, dirs, unknowns = tree_map
- self.assertEqual(files, [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS)])
- self.assertEqual(dirs, [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR)])
- self.assertEqual(unknowns, [])
+ assert files == [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS)]
+ assert dirs == [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR)]
+ assert unknowns == []
def tearDown(self):
shutil.rmtree(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
diff --git a/tests/providers/sftp/operators/test_sftp.py b/tests/providers/sftp/operators/test_sftp.py
index 927b6d6060c77..aa752b1807f33 100644
--- a/tests/providers/sftp/operators/test_sftp.py
+++ b/tests/providers/sftp/operators/test_sftp.py
@@ -21,6 +21,8 @@
from base64 import b64encode
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.models import DAG, TaskInstance
from airflow.providers.sftp.operators.sftp import SFTPOperation, SFTPOperator
@@ -80,7 +82,7 @@ def test_pickle_file_transfer_put(self):
create_intermediate_dirs=True,
dag=self.dag,
)
- self.assertIsNotNone(put_test_task)
+ assert put_test_task is not None
ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
ti2.run()
@@ -92,12 +94,12 @@ def test_pickle_file_transfer_put(self):
do_xcom_push=True,
dag=self.dag,
)
- self.assertIsNotNone(check_file_task)
+ assert check_file_task is not None
ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
ti3.run()
- self.assertEqual(
- ti3.xcom_pull(task_ids=check_file_task.task_id, key='return_value').strip(),
- test_local_file_content,
+ assert (
+ ti3.xcom_pull(task_ids=check_file_task.task_id, key='return_value').strip()
+ == test_local_file_content
)
@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
@@ -113,7 +115,7 @@ def test_file_transfer_no_intermediate_dir_error_put(self):
# Try to put test file to remote
# This should raise an error with "No such file" as the directory
# does not exist
- with self.assertRaises(Exception) as error:
+ with pytest.raises(Exception) as ctx:
put_test_task = SFTPOperator(
task_id="test_sftp",
ssh_hook=self.hook,
@@ -123,10 +125,10 @@ def test_file_transfer_no_intermediate_dir_error_put(self):
create_intermediate_dirs=False,
dag=self.dag,
)
- self.assertIsNotNone(put_test_task)
+ assert put_test_task is not None
ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
ti2.run()
- self.assertIn('No such file', str(error.exception))
+ assert 'No such file' in str(ctx.value)
@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
def test_file_transfer_with_intermediate_dir_put(self):
@@ -148,7 +150,7 @@ def test_file_transfer_with_intermediate_dir_put(self):
create_intermediate_dirs=True,
dag=self.dag,
)
- self.assertIsNotNone(put_test_task)
+ assert put_test_task is not None
ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
ti2.run()
@@ -160,11 +162,11 @@ def test_file_transfer_with_intermediate_dir_put(self):
do_xcom_push=True,
dag=self.dag,
)
- self.assertIsNotNone(check_file_task)
+ assert check_file_task is not None
ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
ti3.run()
- self.assertEqual(
- ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(), test_local_file_content
+ assert (
+ ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip() == test_local_file_content
)
@conf_vars({('core', 'enable_xcom_pickling'): 'False'})
@@ -186,7 +188,7 @@ def test_json_file_transfer_put(self):
operation=SFTPOperation.PUT,
dag=self.dag,
)
- self.assertIsNotNone(put_test_task)
+ assert put_test_task is not None
ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
ti2.run()
@@ -198,13 +200,12 @@ def test_json_file_transfer_put(self):
do_xcom_push=True,
dag=self.dag,
)
- self.assertIsNotNone(check_file_task)
+ assert check_file_task is not None
ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
ti3.run()
- self.assertEqual(
- ti3.xcom_pull(task_ids=check_file_task.task_id, key='return_value').strip(),
- b64encode(test_local_file_content).decode('utf-8'),
- )
+ assert ti3.xcom_pull(task_ids=check_file_task.task_id, key='return_value').strip() == b64encode(
+ test_local_file_content
+ ).decode('utf-8')
@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
def test_pickle_file_transfer_get(self):
@@ -221,7 +222,7 @@ def test_pickle_file_transfer_get(self):
do_xcom_push=True,
dag=self.dag,
)
- self.assertIsNotNone(create_file_task)
+ assert create_file_task is not None
ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
ti1.run()
@@ -234,7 +235,7 @@ def test_pickle_file_transfer_get(self):
operation=SFTPOperation.GET,
dag=self.dag,
)
- self.assertIsNotNone(get_test_task)
+ assert get_test_task is not None
ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
ti2.run()
@@ -242,7 +243,7 @@ def test_pickle_file_transfer_get(self):
content_received = None
with open(self.test_local_filepath) as file:
content_received = file.read()
- self.assertEqual(content_received.strip(), test_remote_file_content)
+ assert content_received.strip() == test_remote_file_content
@conf_vars({('core', 'enable_xcom_pickling'): 'False'})
def test_json_file_transfer_get(self):
@@ -259,7 +260,7 @@ def test_json_file_transfer_get(self):
do_xcom_push=True,
dag=self.dag,
)
- self.assertIsNotNone(create_file_task)
+ assert create_file_task is not None
ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
ti1.run()
@@ -272,7 +273,7 @@ def test_json_file_transfer_get(self):
operation=SFTPOperation.GET,
dag=self.dag,
)
- self.assertIsNotNone(get_test_task)
+ assert get_test_task is not None
ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
ti2.run()
@@ -280,7 +281,7 @@ def test_json_file_transfer_get(self):
content_received = None
with open(self.test_local_filepath) as file:
content_received = file.read()
- self.assertEqual(content_received.strip(), test_remote_file_content.encode('utf-8').decode('utf-8'))
+ assert content_received.strip() == test_remote_file_content.encode('utf-8').decode('utf-8')
@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
def test_file_transfer_no_intermediate_dir_error_get(self):
@@ -297,14 +298,14 @@ def test_file_transfer_no_intermediate_dir_error_get(self):
do_xcom_push=True,
dag=self.dag,
)
- self.assertIsNotNone(create_file_task)
+ assert create_file_task is not None
ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
ti1.run()
# Try to GET test file from remote
# This should raise an error with "No such file" as the directory
# does not exist
- with self.assertRaises(Exception) as error:
+ with pytest.raises(Exception) as ctx:
get_test_task = SFTPOperator(
task_id="test_sftp",
ssh_hook=self.hook,
@@ -313,10 +314,10 @@ def test_file_transfer_no_intermediate_dir_error_get(self):
operation=SFTPOperation.GET,
dag=self.dag,
)
- self.assertIsNotNone(get_test_task)
+ assert get_test_task is not None
ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
ti2.run()
- self.assertIn('No such file', str(error.exception))
+ assert 'No such file' in str(ctx.value)
@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
def test_file_transfer_with_intermediate_dir_error_get(self):
@@ -333,7 +334,7 @@ def test_file_transfer_with_intermediate_dir_error_get(self):
do_xcom_push=True,
dag=self.dag,
)
- self.assertIsNotNone(create_file_task)
+ assert create_file_task is not None
ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
ti1.run()
@@ -347,7 +348,7 @@ def test_file_transfer_with_intermediate_dir_error_get(self):
create_intermediate_dirs=True,
dag=self.dag,
)
- self.assertIsNotNone(get_test_task)
+ assert get_test_task is not None
ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
ti2.run()
@@ -355,12 +356,12 @@ def test_file_transfer_with_intermediate_dir_error_get(self):
content_received = None
with open(self.test_local_filepath_int_dir) as file:
content_received = file.read()
- self.assertEqual(content_received.strip(), test_remote_file_content)
+ assert content_received.strip() == test_remote_file_content
@mock.patch.dict('os.environ', {'AIRFLOW_CONN_' + TEST_CONN_ID.upper(): "ssh://test_id@localhost"})
def test_arg_checking(self):
# Exception should be raised if neither ssh_hook nor ssh_conn_id is provided
- with self.assertRaisesRegex(AirflowException, "Cannot operate without ssh_hook or ssh_conn_id."):
+ with pytest.raises(AirflowException, match="Cannot operate without ssh_hook or ssh_conn_id."):
task_0 = SFTPOperator(
task_id="test_sftp_0",
local_filepath=self.test_local_filepath,
@@ -384,7 +385,7 @@ def test_arg_checking(self):
task_1.execute(None)
except Exception: # pylint: disable=broad-except
pass
- self.assertEqual(task_1.ssh_hook.ssh_conn_id, TEST_CONN_ID)
+ assert task_1.ssh_hook.ssh_conn_id == TEST_CONN_ID
task_2 = SFTPOperator(
task_id="test_sftp_2",
@@ -398,7 +399,7 @@ def test_arg_checking(self):
task_2.execute(None)
except Exception: # pylint: disable=broad-except
pass
- self.assertEqual(task_2.ssh_hook.ssh_conn_id, TEST_CONN_ID)
+ assert task_2.ssh_hook.ssh_conn_id == TEST_CONN_ID
# if both valid ssh_hook and ssh_conn_id are provided, ignore ssh_conn_id
task_3 = SFTPOperator(
@@ -414,7 +415,7 @@ def test_arg_checking(self):
task_3.execute(None)
except Exception: # pylint: disable=broad-except
pass
- self.assertEqual(task_3.ssh_hook.ssh_conn_id, self.hook.ssh_conn_id)
+ assert task_3.ssh_hook.ssh_conn_id == self.hook.ssh_conn_id
def delete_local_resource(self):
if os.path.exists(self.test_local_filepath):
@@ -434,7 +435,7 @@ def delete_remote_resource(self):
do_xcom_push=True,
dag=self.dag,
)
- self.assertIsNotNone(remove_file_task)
+ assert remove_file_task is not None
ti3 = TaskInstance(task=remove_file_task, execution_date=timezone.utcnow())
ti3.run()
if os.path.exists(self.test_remote_filepath_int_dir):
diff --git a/tests/providers/sftp/sensors/test_sftp.py b/tests/providers/sftp/sensors/test_sftp.py
index 6115ab956eec0..682e6eb83e97e 100644
--- a/tests/providers/sftp/sensors/test_sftp.py
+++ b/tests/providers/sftp/sensors/test_sftp.py
@@ -19,6 +19,7 @@
import unittest
from unittest.mock import patch
+import pytest
from paramiko import SFTP_FAILURE, SFTP_NO_SUCH_FILE
from airflow.providers.sftp.sensors.sftp import SFTPSensor
@@ -32,7 +33,7 @@ def test_file_present(self, sftp_hook_mock):
context = {'ds': '1970-01-01'}
output = sftp_sensor.poke(context)
sftp_hook_mock.return_value.get_mod_time.assert_called_once_with('/path/to/file/1970-01-01.txt')
- self.assertTrue(output)
+ assert output
@patch('airflow.providers.sftp.sensors.sftp.SFTPHook')
def test_file_absent(self, sftp_hook_mock):
@@ -41,17 +42,17 @@ def test_file_absent(self, sftp_hook_mock):
context = {'ds': '1970-01-01'}
output = sftp_sensor.poke(context)
sftp_hook_mock.return_value.get_mod_time.assert_called_once_with('/path/to/file/1970-01-01.txt')
- self.assertFalse(output)
+ assert not output
@patch('airflow.providers.sftp.sensors.sftp.SFTPHook')
def test_sftp_failure(self, sftp_hook_mock):
sftp_hook_mock.return_value.get_mod_time.side_effect = OSError(SFTP_FAILURE, 'SFTP failure')
sftp_sensor = SFTPSensor(task_id='unit_test', path='/path/to/file/1970-01-01.txt')
context = {'ds': '1970-01-01'}
- with self.assertRaises(OSError):
+ with pytest.raises(OSError):
sftp_sensor.poke(context)
sftp_hook_mock.return_value.get_mod_time.assert_called_once_with('/path/to/file/1970-01-01.txt')
def test_hook_not_created_during_init(self):
sftp_sensor = SFTPSensor(task_id='unit_test', path='/path/to/file/1970-01-01.txt')
- self.assertIsNone(sftp_sensor.hook)
+ assert sftp_sensor.hook is None
diff --git a/tests/providers/singularity/operators/test_singularity.py b/tests/providers/singularity/operators/test_singularity.py
index 71ab5ace64db9..4316bda007154 100644
--- a/tests/providers/singularity/operators/test_singularity.py
+++ b/tests/providers/singularity/operators/test_singularity.py
@@ -19,6 +19,7 @@
import unittest
from unittest import mock
+import pytest
from parameterized import parameterized
from spython.instance import Instance
@@ -48,7 +49,7 @@ def test_execute(self, client_mock):
client_mock.execute.assert_called_once_with(mock.ANY, "echo hello", return_result=True)
execute_args, _ = client_mock.execute.call_args
- self.assertIs(execute_args[0], instance)
+ assert execute_args[0] is instance
instance.start.assert_called_once_with()
instance.stop.assert_called_once_with()
@@ -61,7 +62,7 @@ def test_execute(self, client_mock):
)
def test_command_is_required(self, command):
task = SingularityOperator(task_id='task-id', image="docker://busybox", command=command)
- with self.assertRaisesRegex(AirflowException, "You must define a command."):
+ with pytest.raises(AirflowException, match="You must define a command."):
task.execute({})
@mock.patch('airflow.providers.singularity.operators.singularity.Client')
diff --git a/tests/providers/slack/hooks/test_slack.py b/tests/providers/slack/hooks/test_slack.py
index 6ebec3a38386d..cbe3d26654aca 100644
--- a/tests/providers/slack/hooks/test_slack.py
+++ b/tests/providers/slack/hooks/test_slack.py
@@ -19,6 +19,7 @@
import unittest
from unittest import mock
+import pytest
from slack.errors import SlackApiError
from airflow.exceptions import AirflowException
@@ -38,7 +39,7 @@ def test_get_token_with_token_only(self):
# Assert
output = hook.token
expected = test_token
- self.assertEqual(output, expected)
+ assert output == expected
@mock.patch('airflow.providers.slack.hooks.slack.WebClient')
@mock.patch('airflow.providers.slack.hooks.slack.SlackHook.get_connection')
@@ -58,7 +59,7 @@ def test_get_token_with_valid_slack_conn_id_only(self, get_connection_mock, mock
# Assert
output = hook.token
expected = test_password
- self.assertEqual(output, expected)
+ assert output == expected
mock_slack_client.assert_called_once_with(test_password)
@mock.patch('airflow.providers.slack.hooks.slack.SlackHook.get_connection')
@@ -71,7 +72,8 @@ def test_get_token_with_no_password_slack_conn_id_only(self, get_connection_mock
get_connection_mock.return_value = conn
# Assert
- self.assertRaises(AirflowException, SlackHook, token=None, slack_conn_id='x')
+ with pytest.raises(AirflowException):
+ SlackHook(token=None, slack_conn_id='x')
@mock.patch('airflow.providers.slack.hooks.slack.SlackHook.get_connection')
def test_get_token_with_empty_password_slack_conn_id_only(self, get_connection_mock):
@@ -81,7 +83,8 @@ def test_get_token_with_empty_password_slack_conn_id_only(self, get_connection_m
get_connection_mock.return_value = mock.Mock(password=None)
# Assert
- self.assertRaises(AirflowException, SlackHook, token=None, slack_conn_id='x')
+ with pytest.raises(AirflowException):
+ SlackHook(token=None, slack_conn_id='x')
def test_get_token_with_token_and_slack_conn_id(self):
"""tests `__get_token` method when both arguments are provided """
@@ -95,12 +98,13 @@ def test_get_token_with_token_and_slack_conn_id(self):
# Assert
output = hook.token
expected = test_token
- self.assertEqual(output, expected)
+ assert output == expected
def test_get_token_with_out_token_nor_slack_conn_id(self):
"""tests `__get_token` method when no arguments are provided """
- self.assertRaises(AirflowException, SlackHook, token=None, slack_conn_id=None)
+ with pytest.raises(AirflowException):
+ SlackHook(token=None, slack_conn_id=None)
@mock.patch('airflow.providers.slack.hooks.slack.WebClient')
def test_call_with_failure(self, slack_client_class_mock):
@@ -115,7 +119,7 @@ def test_call_with_failure(self, slack_client_class_mock):
test_method = 'test_method'
test_api_params = {'key1': 'value1', 'key2': 'value2'}
- with self.assertRaises(SlackApiError):
+ with pytest.raises(SlackApiError):
slack_hook.call(test_method, test_api_params)
@mock.patch('airflow.providers.slack.hooks.slack.WebClient.api_call', autospec=True)
diff --git a/tests/providers/slack/hooks/test_slack_webhook.py b/tests/providers/slack/hooks/test_slack_webhook.py
index 30a887fa2e7ba..6fce527a76e23 100644
--- a/tests/providers/slack/hooks/test_slack_webhook.py
+++ b/tests/providers/slack/hooks/test_slack_webhook.py
@@ -93,7 +93,7 @@ def test_get_token_manual_token(self):
webhook_token = hook._get_token(manual_token, None)
# Then
- self.assertEqual(webhook_token, manual_token)
+ assert webhook_token == manual_token
def test_get_token_conn_id(self):
# Given
@@ -105,7 +105,7 @@ def test_get_token_conn_id(self):
webhook_token = hook._get_token(None, conn_id)
# Then
- self.assertEqual(webhook_token, expected_webhook_token)
+ assert webhook_token == expected_webhook_token
def test_get_token_conn_id_password(self):
# Given
@@ -117,7 +117,7 @@ def test_get_token_conn_id_password(self):
webhook_token = hook._get_token(None, conn_id)
# Then
- self.assertEqual(webhook_token, expected_webhook_token)
+ assert webhook_token == expected_webhook_token
def test_build_slack_message(self):
# Given
@@ -127,7 +127,7 @@ def test_build_slack_message(self):
message = hook._build_slack_message()
# Then
- self.assertEqual(self.expected_message_dict, json.loads(message))
+ assert self.expected_message_dict == json.loads(message)
@mock.patch('requests.Session')
@mock.patch('requests.Request')
diff --git a/tests/providers/slack/operators/test_slack.py b/tests/providers/slack/operators/test_slack.py
index e505282c033b3..c0c0383eeb925 100644
--- a/tests/providers/slack/operators/test_slack.py
+++ b/tests/providers/slack/operators/test_slack.py
@@ -94,20 +94,20 @@ def test_init_with_valid_params(self):
test_slack_conn_id = 'test_slack_conn_id'
slack_api_post_operator = self.__construct_operator(test_token, None, self.test_api_params)
- self.assertEqual(slack_api_post_operator.token, test_token)
- self.assertEqual(slack_api_post_operator.slack_conn_id, None)
- self.assertEqual(slack_api_post_operator.method, self.expected_method)
- self.assertEqual(slack_api_post_operator.text, self.test_text)
- self.assertEqual(slack_api_post_operator.channel, self.test_channel)
- self.assertEqual(slack_api_post_operator.api_params, self.test_api_params)
- self.assertEqual(slack_api_post_operator.username, self.test_username)
- self.assertEqual(slack_api_post_operator.icon_url, self.test_icon_url)
- self.assertEqual(slack_api_post_operator.attachments, self.test_attachments)
- self.assertEqual(slack_api_post_operator.blocks, self.test_blocks)
+ assert slack_api_post_operator.token == test_token
+ assert slack_api_post_operator.slack_conn_id is None
+ assert slack_api_post_operator.method == self.expected_method
+ assert slack_api_post_operator.text == self.test_text
+ assert slack_api_post_operator.channel == self.test_channel
+ assert slack_api_post_operator.api_params == self.test_api_params
+ assert slack_api_post_operator.username == self.test_username
+ assert slack_api_post_operator.icon_url == self.test_icon_url
+ assert slack_api_post_operator.attachments == self.test_attachments
+ assert slack_api_post_operator.blocks == self.test_blocks
slack_api_post_operator = self.__construct_operator(None, test_slack_conn_id)
- self.assertEqual(slack_api_post_operator.token, None)
- self.assertEqual(slack_api_post_operator.slack_conn_id, test_slack_conn_id)
+ assert slack_api_post_operator.token is None
+ assert slack_api_post_operator.slack_conn_id == test_slack_conn_id
@mock.patch('airflow.providers.slack.operators.slack.SlackHook')
def test_api_call_params_with_default_args(self, mock_hook):
@@ -132,7 +132,7 @@ def test_api_call_params_with_default_args(self, mock_hook):
'attachments': '[]',
'blocks': '[]',
}
- self.assertEqual(expected_api_params, slack_api_post_operator.api_params)
+ assert expected_api_params == slack_api_post_operator.api_params
class TestSlackAPIFileOperator(unittest.TestCase):
@@ -173,19 +173,19 @@ def test_init_with_valid_params(self):
test_slack_conn_id = 'test_slack_conn_id'
slack_api_post_operator = self.__construct_operator(test_token, None, self.test_api_params)
- self.assertEqual(slack_api_post_operator.token, test_token)
- self.assertEqual(slack_api_post_operator.slack_conn_id, None)
- self.assertEqual(slack_api_post_operator.method, self.expected_method)
- self.assertEqual(slack_api_post_operator.initial_comment, self.test_initial_comment)
- self.assertEqual(slack_api_post_operator.channel, self.test_channel)
- self.assertEqual(slack_api_post_operator.api_params, self.test_api_params)
- self.assertEqual(slack_api_post_operator.filename, self.test_filename)
- self.assertEqual(slack_api_post_operator.filetype, self.test_filetype)
- self.assertEqual(slack_api_post_operator.content, self.test_content)
+ assert slack_api_post_operator.token == test_token
+ assert slack_api_post_operator.slack_conn_id is None
+ assert slack_api_post_operator.method == self.expected_method
+ assert slack_api_post_operator.initial_comment == self.test_initial_comment
+ assert slack_api_post_operator.channel == self.test_channel
+ assert slack_api_post_operator.api_params == self.test_api_params
+ assert slack_api_post_operator.filename == self.test_filename
+ assert slack_api_post_operator.filetype == self.test_filetype
+ assert slack_api_post_operator.content == self.test_content
slack_api_post_operator = self.__construct_operator(None, test_slack_conn_id)
- self.assertEqual(slack_api_post_operator.token, None)
- self.assertEqual(slack_api_post_operator.slack_conn_id, test_slack_conn_id)
+ assert slack_api_post_operator.token is None
+ assert slack_api_post_operator.slack_conn_id == test_slack_conn_id
@mock.patch('airflow.providers.slack.operators.slack.SlackHook')
def test_api_call_params_with_default_args(self, mock_hook):
@@ -205,4 +205,4 @@ def test_api_call_params_with_default_args(self, mock_hook):
'filetype': 'csv',
'content': 'default,content,csv,file',
}
- self.assertEqual(expected_api_params, slack_api_post_operator.api_params)
+ assert expected_api_params == slack_api_post_operator.api_params
diff --git a/tests/providers/slack/operators/test_slack_webhook.py b/tests/providers/slack/operators/test_slack_webhook.py
index 96fd42503e1e0..7b261f4ef4556 100644
--- a/tests/providers/slack/operators/test_slack_webhook.py
+++ b/tests/providers/slack/operators/test_slack_webhook.py
@@ -49,17 +49,17 @@ def test_execute(self):
# Given / When
operator = SlackWebhookOperator(task_id='slack_webhook_job', dag=self.dag, **self._config)
- self.assertEqual(self._config['http_conn_id'], operator.http_conn_id)
- self.assertEqual(self._config['webhook_token'], operator.webhook_token)
- self.assertEqual(self._config['message'], operator.message)
- self.assertEqual(self._config['attachments'], operator.attachments)
- self.assertEqual(self._config['blocks'], operator.blocks)
- self.assertEqual(self._config['channel'], operator.channel)
- self.assertEqual(self._config['username'], operator.username)
- self.assertEqual(self._config['icon_emoji'], operator.icon_emoji)
- self.assertEqual(self._config['icon_url'], operator.icon_url)
- self.assertEqual(self._config['link_names'], operator.link_names)
- self.assertEqual(self._config['proxy'], operator.proxy)
+ assert self._config['http_conn_id'] == operator.http_conn_id
+ assert self._config['webhook_token'] == operator.webhook_token
+ assert self._config['message'] == operator.message
+ assert self._config['attachments'] == operator.attachments
+ assert self._config['blocks'] == operator.blocks
+ assert self._config['channel'] == operator.channel
+ assert self._config['username'] == operator.username
+ assert self._config['icon_emoji'] == operator.icon_emoji
+ assert self._config['icon_url'] == operator.icon_url
+ assert self._config['link_names'] == operator.link_names
+ assert self._config['proxy'] == operator.proxy
def test_assert_templated_fields(self):
operator = SlackWebhookOperator(task_id='slack_webhook_job', dag=self.dag, **self._config)
@@ -74,4 +74,4 @@ def test_assert_templated_fields(self):
'proxy',
]
- self.assertEqual(operator.template_fields, template_fields)
+ assert operator.template_fields == template_fields
diff --git a/tests/providers/snowflake/hooks/test_snowflake.py b/tests/providers/snowflake/hooks/test_snowflake.py
index fd21915b2e236..2ecf73223027f 100644
--- a/tests/providers/snowflake/hooks/test_snowflake.py
+++ b/tests/providers/snowflake/hooks/test_snowflake.py
@@ -87,7 +87,7 @@ def test_get_uri(self):
uri_shouldbe = (
'snowflake://user:pw@airflow/db/public?warehouse=af_wh&role=af_role&authenticator=snowflake'
)
- self.assertEqual(uri_shouldbe, self.db_hook.get_uri())
+ assert uri_shouldbe == self.db_hook.get_uri()
def test_get_conn_params(self):
conn_params_shouldbe = {
@@ -102,11 +102,11 @@ def test_get_conn_params(self):
'authenticator': 'snowflake',
'session_parameters': {"QUERY_TAG": "This is a test hook"},
}
- self.assertEqual(self.db_hook.snowflake_conn_id, 'snowflake_default') # pylint: disable=no-member
- self.assertEqual(conn_params_shouldbe, self.db_hook._get_conn_params())
+ assert self.db_hook.snowflake_conn_id == 'snowflake_default' # pylint: disable=no-member
+ assert conn_params_shouldbe == self.db_hook._get_conn_params()
def test_get_conn(self):
- self.assertEqual(self.db_hook.get_conn(), self.conn)
+ assert self.db_hook.get_conn() == self.conn
def test_key_pair_auth_encrypted(self):
self.conn.extra_dejson = {
@@ -119,7 +119,7 @@ def test_key_pair_auth_encrypted(self):
}
params = self.db_hook._get_conn_params()
- self.assertTrue('private_key' in params)
+ assert 'private_key' in params
def test_key_pair_auth_not_encrypted(self):
self.conn.extra_dejson = {
@@ -133,8 +133,8 @@ def test_key_pair_auth_not_encrypted(self):
self.conn.password = ''
params = self.db_hook._get_conn_params()
- self.assertTrue('private_key' in params)
+ assert 'private_key' in params
self.conn.password = None
params = self.db_hook._get_conn_params()
- self.assertTrue('private_key' in params)
+ assert 'private_key' in params
diff --git a/tests/providers/sqlite/hooks/test_sqlite.py b/tests/providers/sqlite/hooks/test_sqlite.py
index ca7c0ae0d1be0..7a25479373ac8 100644
--- a/tests/providers/sqlite/hooks/test_sqlite.py
+++ b/tests/providers/sqlite/hooks/test_sqlite.py
@@ -72,7 +72,7 @@ def test_get_first_record(self):
result_sets = [('row1',), ('row2',)]
self.cur.fetchone.return_value = result_sets[0]
- self.assertEqual(result_sets[0], self.db_hook.get_first(statement))
+ assert result_sets[0] == self.db_hook.get_first(statement)
self.conn.close.assert_called_once_with()
self.cur.close.assert_called_once_with()
self.cur.execute.assert_called_once_with(statement)
@@ -82,7 +82,7 @@ def test_get_records(self):
result_sets = [('row1',), ('row2',)]
self.cur.fetchall.return_value = result_sets
- self.assertEqual(result_sets, self.db_hook.get_records(statement))
+ assert result_sets == self.db_hook.get_records(statement)
self.conn.close.assert_called_once_with()
self.cur.close.assert_called_once_with()
self.cur.execute.assert_called_once_with(statement)
@@ -95,10 +95,10 @@ def test_get_pandas_df(self):
self.cur.fetchall.return_value = result_sets
df = self.db_hook.get_pandas_df(statement)
- self.assertEqual(column, df.columns[0])
+ assert column == df.columns[0]
- self.assertEqual(result_sets[0][0], df.values.tolist()[0][0])
- self.assertEqual(result_sets[1][0], df.values.tolist()[1][0])
+ assert result_sets[0][0] == df.values.tolist()[0][0]
+ assert result_sets[1][0] == df.values.tolist()[1][0]
self.cur.execute.assert_called_once_with(statement)
diff --git a/tests/providers/ssh/hooks/test_ssh.py b/tests/providers/ssh/hooks/test_ssh.py
index 42a752d1f2153..fea52bc5e01b6 100644
--- a/tests/providers/ssh/hooks/test_ssh.py
+++ b/tests/providers/ssh/hooks/test_ssh.py
@@ -258,14 +258,14 @@ def test_tunnel_without_password(self, ssh_mock):
def test_conn_with_extra_parameters(self):
ssh_hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_EXTRA)
- self.assertEqual(ssh_hook.compress, True)
- self.assertEqual(ssh_hook.no_host_key_check, True)
- self.assertEqual(ssh_hook.allow_host_key_change, False)
- self.assertEqual(ssh_hook.look_for_keys, True)
+ assert ssh_hook.compress is True
+ assert ssh_hook.no_host_key_check is True
+ assert ssh_hook.allow_host_key_change is False
+ assert ssh_hook.look_for_keys is True
def test_conn_with_extra_parameters_false_look_for_keys(self):
ssh_hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_EXTRA_FALSE_LOOK_FOR_KEYS)
- self.assertEqual(ssh_hook.look_for_keys, False)
+ assert ssh_hook.look_for_keys is False
@mock.patch('airflow.providers.ssh.hooks.ssh.SSHTunnelForwarder')
def test_tunnel_with_private_key(self, ssh_mock):
@@ -318,21 +318,21 @@ def test_ssh_connection(self):
with hook.get_conn() as client:
# Note - Pylint will fail with no-member here due to https://github.com/PyCQA/pylint/issues/1437
(_, stdout, _) = client.exec_command('ls') # pylint: disable=no-member
- self.assertIsNotNone(stdout.read())
+ assert stdout.read() is not None
def test_ssh_connection_no_connection_id(self):
hook = SSHHook(remote_host='localhost')
- self.assertIsNone(hook.ssh_conn_id)
+ assert hook.ssh_conn_id is None
with hook.get_conn() as client:
# Note - Pylint will fail with no-member here due to https://github.com/PyCQA/pylint/issues/1437
(_, stdout, _) = client.exec_command('ls') # pylint: disable=no-member
- self.assertIsNotNone(stdout.read())
+ assert stdout.read() is not None
def test_ssh_connection_old_cm(self):
with SSHHook(ssh_conn_id='ssh_default') as hook:
client = hook.get_conn()
(_, stdout, _) = client.exec_command('ls')
- self.assertIsNotNone(stdout.read())
+ assert stdout.read() is not None
def test_tunnel(self):
hook = SSHHook(ssh_conn_id='ssh_default')
@@ -346,14 +346,14 @@ def test_tunnel(self):
)
with subprocess.Popen(**subprocess_kwargs) as server_handle, hook.create_tunnel(2135, 2134):
server_output = server_handle.stdout.read(5)
- self.assertEqual(b"ready", server_output)
+ assert b"ready" == server_output
socket = socket.socket()
socket.connect(("localhost", 2135))
response = socket.recv(5)
- self.assertEqual(response, b"hello")
+ assert response == b"hello"
socket.close()
server_handle.communicate()
- self.assertEqual(server_handle.returncode, 0)
+ assert server_handle.returncode == 0
@mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
def test_ssh_connection_with_private_key_extra(self, ssh_mock):
diff --git a/tests/providers/ssh/operators/test_ssh.py b/tests/providers/ssh/operators/test_ssh.py
index 41302b171c350..28b19ea020105 100644
--- a/tests/providers/ssh/operators/test_ssh.py
+++ b/tests/providers/ssh/operators/test_ssh.py
@@ -19,6 +19,7 @@
import unittest.mock
from base64 import b64encode
+import pytest
from parameterized import parameterized
from airflow.exceptions import AirflowException
@@ -57,12 +58,12 @@ def test_hook_created_correctly(self):
task = SSHOperator(
task_id="test", command=COMMAND, dag=self.dag, timeout=timeout, ssh_conn_id="ssh_default"
)
- self.assertIsNotNone(task)
+ assert task is not None
task.execute(None)
- self.assertEqual(timeout, task.ssh_hook.timeout)
- self.assertEqual(ssh_id, task.ssh_hook.ssh_conn_id)
+ assert timeout == task.ssh_hook.timeout
+ assert ssh_id == task.ssh_hook.ssh_conn_id
@conf_vars({('core', 'enable_xcom_pickling'): 'False'})
def test_json_command_execution(self):
@@ -74,14 +75,12 @@ def test_json_command_execution(self):
dag=self.dag,
)
- self.assertIsNotNone(task)
+ assert task is not None
ti = TaskInstance(task=task, execution_date=timezone.utcnow())
ti.run()
- self.assertIsNotNone(ti.duration)
- self.assertEqual(
- ti.xcom_pull(task_ids='test', key='return_value'), b64encode(b'airflow').decode('utf-8')
- )
+ assert ti.duration is not None
+ assert ti.xcom_pull(task_ids='test', key='return_value') == b64encode(b'airflow').decode('utf-8')
@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
def test_pickle_command_execution(self):
@@ -93,12 +92,12 @@ def test_pickle_command_execution(self):
dag=self.dag,
)
- self.assertIsNotNone(task)
+ assert task is not None
ti = TaskInstance(task=task, execution_date=timezone.utcnow())
ti.run()
- self.assertIsNotNone(ti.duration)
- self.assertEqual(ti.xcom_pull(task_ids='test', key='return_value'), b'airflow')
+ assert ti.duration is not None
+ assert ti.xcom_pull(task_ids='test', key='return_value') == b'airflow'
def test_command_execution_with_env(self):
task = SSHOperator(
@@ -110,13 +109,13 @@ def test_command_execution_with_env(self):
environment={'TEST': 'value'},
)
- self.assertIsNotNone(task)
+ assert task is not None
with conf_vars({('core', 'enable_xcom_pickling'): 'True'}):
ti = TaskInstance(task=task, execution_date=timezone.utcnow())
ti.run()
- self.assertIsNotNone(ti.duration)
- self.assertEqual(ti.xcom_pull(task_ids='test', key='return_value'), b'airflow')
+ assert ti.duration is not None
+ assert ti.xcom_pull(task_ids='test', key='return_value') == b'airflow'
def test_no_output_command(self):
task = SSHOperator(
@@ -127,18 +126,18 @@ def test_no_output_command(self):
dag=self.dag,
)
- self.assertIsNotNone(task)
+ assert task is not None
with conf_vars({('core', 'enable_xcom_pickling'): 'True'}):
ti = TaskInstance(task=task, execution_date=timezone.utcnow())
ti.run()
- self.assertIsNotNone(ti.duration)
- self.assertEqual(ti.xcom_pull(task_ids='test', key='return_value'), b'')
+ assert ti.duration is not None
+ assert ti.xcom_pull(task_ids='test', key='return_value') == b''
@unittest.mock.patch('os.environ', {'AIRFLOW_CONN_' + TEST_CONN_ID.upper(): "ssh://test_id@localhost"})
def test_arg_checking(self):
# Exception should be raised if neither ssh_hook nor ssh_conn_id is provided
- with self.assertRaisesRegex(AirflowException, "Cannot operate without ssh_hook or ssh_conn_id."):
+ with pytest.raises(AirflowException, match="Cannot operate without ssh_hook or ssh_conn_id."):
task_0 = SSHOperator(task_id="test", command=COMMAND, timeout=TIMEOUT, dag=self.dag)
task_0.execute(None)
@@ -155,7 +154,7 @@ def test_arg_checking(self):
task_1.execute(None)
except Exception: # pylint: disable=broad-except
pass
- self.assertEqual(task_1.ssh_hook.ssh_conn_id, TEST_CONN_ID)
+ assert task_1.ssh_hook.ssh_conn_id == TEST_CONN_ID
task_2 = SSHOperator(
task_id="test_2",
@@ -168,7 +167,7 @@ def test_arg_checking(self):
task_2.execute(None)
except Exception: # pylint: disable=broad-except
pass
- self.assertEqual(task_2.ssh_hook.ssh_conn_id, TEST_CONN_ID)
+ assert task_2.ssh_hook.ssh_conn_id == TEST_CONN_ID
# if both valid ssh_hook and ssh_conn_id are provided, ignore ssh_conn_id
task_3 = SSHOperator(
@@ -183,7 +182,7 @@ def test_arg_checking(self):
task_3.execute(None)
except Exception: # pylint: disable=broad-except
pass
- self.assertEqual(task_3.ssh_hook.ssh_conn_id, self.hook.ssh_conn_id)
+ assert task_3.ssh_hook.ssh_conn_id == self.hook.ssh_conn_id
@parameterized.expand(
[
@@ -207,4 +206,4 @@ def test_get_pyt_set_correctly(self, command, get_pty_in, get_pty_out):
task.execute(None)
except Exception: # pylint: disable=broad-except
pass
- self.assertEqual(task.get_pty, get_pty_out)
+ assert task.get_pty == get_pty_out
diff --git a/tests/providers/telegram/hooks/test_telegram.py b/tests/providers/telegram/hooks/test_telegram.py
index 1215be41d199d..2b5e4812ed06c 100644
--- a/tests/providers/telegram/hooks/test_telegram.py
+++ b/tests/providers/telegram/hooks/test_telegram.py
@@ -19,6 +19,7 @@
import unittest
from unittest import mock
+import pytest
import telegram
import airflow
@@ -54,38 +55,38 @@ def setUp(self):
)
def test_should_raise_exception_if_both_connection_or_token_is_not_provided(self):
- with self.assertRaises(airflow.exceptions.AirflowException) as e:
+ with pytest.raises(airflow.exceptions.AirflowException) as ctx:
TelegramHook()
- self.assertEqual("Cannot get token: No valid Telegram connection supplied.", str(e.exception))
+ assert "Cannot get token: No valid Telegram connection supplied." == str(ctx.value)
def test_should_raise_exception_if_conn_id_doesnt_exist(self):
- with self.assertRaises(airflow.exceptions.AirflowNotFoundException) as e:
+ with pytest.raises(airflow.exceptions.AirflowNotFoundException) as ctx:
TelegramHook(telegram_conn_id='telegram-webhook-non-existent')
- self.assertEqual("The conn_id `telegram-webhook-non-existent` isn't defined", str(e.exception))
+ assert "The conn_id `telegram-webhook-non-existent` isn't defined" == str(ctx.value)
def test_should_raise_exception_if_conn_id_doesnt_contain_token(self):
- with self.assertRaises(airflow.exceptions.AirflowException) as e:
+ with pytest.raises(airflow.exceptions.AirflowException) as ctx:
TelegramHook(telegram_conn_id='telegram-webhook-without-token')
- self.assertEqual("Missing token(password) in Telegram connection", str(e.exception))
+ assert "Missing token(password) in Telegram connection" == str(ctx.value)
@mock.patch('airflow.providers.telegram.hooks.telegram.TelegramHook.get_conn')
def test_should_raise_exception_if_chat_id_is_not_provided_anywhere(self, mock_get_conn):
- with self.assertRaises(airflow.exceptions.AirflowException) as e:
+ with pytest.raises(airflow.exceptions.AirflowException) as ctx:
hook = TelegramHook(telegram_conn_id='telegram_default')
hook.send_message({"text": "test telegram message"})
- self.assertEqual("'chat_id' must be provided for telegram message", str(e.exception))
+ assert "'chat_id' must be provided for telegram message" == str(ctx.value)
@mock.patch('airflow.providers.telegram.hooks.telegram.TelegramHook.get_conn')
def test_should_raise_exception_if_message_text_is_not_provided(self, mock_get_conn):
- with self.assertRaises(airflow.exceptions.AirflowException) as e:
+ with pytest.raises(airflow.exceptions.AirflowException) as ctx:
hook = TelegramHook(telegram_conn_id='telegram_default')
hook.send_message({"chat_id": -420913222})
- self.assertEqual("'text' must be provided for telegram message", str(e.exception))
+ assert "'text' must be provided for telegram message" == str(ctx.value)
@mock.patch('airflow.providers.telegram.hooks.telegram.TelegramHook.get_conn')
def test_should_send_message_if_all_parameters_are_correctly_provided(self, mock_get_conn):
@@ -154,12 +155,12 @@ def side_effect(*args, **kwargs):
mock_get_conn.return_value.send_message.side_effect = side_effect
- with self.assertRaises(Exception) as e:
+ with pytest.raises(Exception) as ctx:
hook = TelegramHook(telegram_conn_id='telegram-webhook-with-chat_id')
hook.send_message({"text": "test telegram message"})
- self.assertTrue("RetryError" in str(e.exception))
- self.assertTrue("state=finished raised TelegramError" in str(e.exception))
+ assert "RetryError" in str(ctx.value)
+ assert "state=finished raised TelegramError" in str(ctx.value)
mock_get_conn.assert_called_once()
mock_get_conn.return_value.send_message.assert_called_with(
@@ -170,7 +171,7 @@ def side_effect(*args, **kwargs):
'text': 'test telegram message',
}
)
- self.assertEqual(excepted_retry_count, mock_get_conn.return_value.send_message.call_count)
+ assert excepted_retry_count == mock_get_conn.return_value.send_message.call_count
@mock.patch('airflow.providers.telegram.hooks.telegram.TelegramHook.get_conn')
def test_should_send_message_if_token_is_provided(self, mock_get_conn):
diff --git a/tests/providers/telegram/operators/test_telegram.py b/tests/providers/telegram/operators/test_telegram.py
index b1629f3f3b818..53b43602a6833 100644
--- a/tests/providers/telegram/operators/test_telegram.py
+++ b/tests/providers/telegram/operators/test_telegram.py
@@ -18,6 +18,7 @@
import unittest
from unittest import mock
+import pytest
import telegram
import airflow
@@ -69,10 +70,10 @@ def test_should_send_message_when_all_parameters_are_provided(self, mock_telegra
)
def test_should_throw_exception_if_connection_id_is_none(self):
- with self.assertRaises(airflow.exceptions.AirflowException) as e:
+ with pytest.raises(airflow.exceptions.AirflowException) as ctx:
TelegramOperator(task_id="telegram", telegram_conn_id=None)
- self.assertEqual("No valid Telegram connection id supplied.", str(e.exception))
+ assert "No valid Telegram connection id supplied." == str(ctx.value)
@mock.patch('airflow.providers.telegram.operators.telegram.TelegramHook')
def test_should_throw_exception_if_telegram_hook_throws_any_exception(self, mock_telegram_hook):
@@ -82,7 +83,7 @@ def side_effect(*args, **kwargs):
mock_telegram_hook.return_value = mock.Mock()
mock_telegram_hook.return_value.send_message.side_effect = side_effect
- with self.assertRaises(telegram.error.TelegramError) as e:
+ with pytest.raises(telegram.error.TelegramError) as ctx:
hook = TelegramOperator(
telegram_conn_id='telegram_default',
task_id='telegram',
@@ -90,7 +91,7 @@ def side_effect(*args, **kwargs):
)
hook.execute()
- self.assertEqual("cosmic rays caused bit flips", str(e.exception))
+ assert "cosmic rays caused bit flips" == str(ctx.value)
@mock.patch('airflow.providers.telegram.operators.telegram.TelegramHook')
def test_should_forward_all_args_to_telegram(self, mock_telegram_hook):
@@ -146,4 +147,4 @@ def test_should_return_template_fields(self):
text="some non empty text - higher precedence",
telegram_kwargs={"custom_arg": "value", "text": "some text, that will be ignored"},
)
- self.assertEqual(('text', 'chat_id'), hook.template_fields)
+ assert ('text', 'chat_id') == hook.template_fields
diff --git a/tests/providers/vertica/hooks/test_vertica.py b/tests/providers/vertica/hooks/test_vertica.py
index a953015caa70b..513627fc219f6 100644
--- a/tests/providers/vertica/hooks/test_vertica.py
+++ b/tests/providers/vertica/hooks/test_vertica.py
@@ -82,7 +82,7 @@ def test_get_first_record(self):
result_sets = [('row1',), ('row2',)]
self.cur.fetchone.return_value = result_sets[0]
- self.assertEqual(result_sets[0], self.db_hook.get_first(statement))
+ assert result_sets[0] == self.db_hook.get_first(statement)
self.conn.close.assert_called_once_with()
self.cur.close.assert_called_once_with()
self.cur.execute.assert_called_once_with(statement)
@@ -92,7 +92,7 @@ def test_get_records(self):
result_sets = [('row1',), ('row2',)]
self.cur.fetchall.return_value = result_sets
- self.assertEqual(result_sets, self.db_hook.get_records(statement))
+ assert result_sets == self.db_hook.get_records(statement)
self.conn.close.assert_called_once_with()
self.cur.close.assert_called_once_with()
self.cur.execute.assert_called_once_with(statement)
@@ -105,7 +105,7 @@ def test_get_pandas_df(self):
self.cur.fetchall.return_value = result_sets
df = self.db_hook.get_pandas_df(statement)
- self.assertEqual(column, df.columns[0])
+ assert column == df.columns[0]
- self.assertEqual(result_sets[0][0], df.values.tolist()[0][0])
- self.assertEqual(result_sets[1][0], df.values.tolist()[1][0])
+ assert result_sets[0][0] == df.values.tolist()[0][0]
+ assert result_sets[1][0] == df.values.tolist()[1][0]
diff --git a/tests/providers/yandex/hooks/test_yandex.py b/tests/providers/yandex/hooks/test_yandex.py
index 5229241173612..469c16b9fe6be 100644
--- a/tests/providers/yandex/hooks/test_yandex.py
+++ b/tests/providers/yandex/hooks/test_yandex.py
@@ -19,6 +19,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook
@@ -42,7 +44,7 @@ def test_client_created_without_exceptions(self, get_credentials_mock, get_conne
get_credentials_mock.return_value = {"token": 122323}
hook = YandexCloudBaseHook(None, default_folder_id, default_public_ssh_key)
- self.assertIsNotNone(hook.client)
+ assert hook.client is not None
@mock.patch('airflow.hooks.base.BaseHook.get_connection')
def test_get_credentials_raise_exception(self, get_connection_mock):
@@ -60,9 +62,8 @@ def test_get_credentials_raise_exception(self, get_connection_mock):
connection_id='yandexcloud_default', extra_dejson=extra_dejson
)
- self.assertRaises(
- AirflowException, YandexCloudBaseHook, None, default_folder_id, default_public_ssh_key
- )
+ with pytest.raises(AirflowException):
+ YandexCloudBaseHook(None, default_folder_id, default_public_ssh_key)
@mock.patch('airflow.hooks.base.BaseHook.get_connection')
@mock.patch('airflow.providers.yandex.hooks.yandex.YandexCloudBaseHook._get_credentials')
@@ -81,4 +82,4 @@ def test_get_field(self, get_credentials_mock, get_connection_mock):
hook = YandexCloudBaseHook(None, default_folder_id, default_public_ssh_key)
- self.assertEqual(hook._get_field('one'), 'value_one')
+ assert hook._get_field('one') == 'value_one'
diff --git a/tests/providers/yandex/hooks/test_yandexcloud_dataproc.py b/tests/providers/yandex/hooks/test_yandexcloud_dataproc.py
index 851954b8c6fee..fd6defddae95a 100644
--- a/tests/providers/yandex/hooks/test_yandexcloud_dataproc.py
+++ b/tests/providers/yandex/hooks/test_yandexcloud_dataproc.py
@@ -87,13 +87,13 @@ def test_create_dataproc_cluster_mocked(self, create_operation_mock):
cluster_image_version=CLUSTER_IMAGE_VERSION,
service_account_id=SERVICE_ACCOUNT_ID,
)
- self.assertTrue(create_operation_mock.called)
+ assert create_operation_mock.called
@patch('yandexcloud.SDK.create_operation_and_get_result')
def test_delete_dataproc_cluster_mocked(self, create_operation_mock):
self._init_hook()
self.hook.client.delete_cluster('my_cluster_id')
- self.assertTrue(create_operation_mock.called)
+ assert create_operation_mock.called
@patch('yandexcloud.SDK.create_operation_and_get_result')
def test_create_hive_job_hook(self, create_operation_mock):
@@ -107,7 +107,7 @@ def test_create_hive_job_hook(self, create_operation_mock):
query='SELECT 1;',
script_variables=None,
)
- self.assertTrue(create_operation_mock.called)
+ assert create_operation_mock.called
@patch('yandexcloud.SDK.create_operation_and_get_result')
def test_create_mapreduce_job_hook(self, create_operation_mock):
@@ -142,7 +142,7 @@ def test_create_mapreduce_job_hook(self, create_operation_mock):
'mapreduce.job.maps': '6',
},
)
- self.assertTrue(create_operation_mock.called)
+ assert create_operation_mock.called
@patch('yandexcloud.SDK.create_operation_and_get_result')
def test_create_spark_job_hook(self, create_operation_mock):
@@ -167,7 +167,7 @@ def test_create_spark_job_hook(self, create_operation_mock):
name='Spark job',
properties={'spark.submit.deployMode': 'cluster'},
)
- self.assertTrue(create_operation_mock.called)
+ assert create_operation_mock.called
@patch('yandexcloud.SDK.create_operation_and_get_result')
def test_create_pyspark_job_hook(self, create_operation_mock):
@@ -191,4 +191,4 @@ def test_create_pyspark_job_hook(self, create_operation_mock):
properties={'spark.submit.deployMode': 'cluster'},
python_file_uris=['s3a://some-in-bucket/jobs/sources/pyspark-001/geonames.py'],
)
- self.assertTrue(create_operation_mock.called)
+ assert create_operation_mock.called
diff --git a/tests/providers/zendesk/hooks/test_zendesk.py b/tests/providers/zendesk/hooks/test_zendesk.py
index 8accfcfc0a8e3..ba3c273b34b1b 100644
--- a/tests/providers/zendesk/hooks/test_zendesk.py
+++ b/tests/providers/zendesk/hooks/test_zendesk.py
@@ -20,6 +20,7 @@
import unittest
from unittest import mock
+import pytest
from zdesk import RateLimitError
from airflow.providers.zendesk.hooks.zendesk import ZendeskHook
@@ -41,7 +42,7 @@ def test_sleeps_for_correct_interval(self, mocked_time):
zendesk_hook = ZendeskHook("conn_id")
zendesk_hook.get_conn = mock.Mock(return_value=conn_mock)
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
zendesk_hook.call("some_path", get_all_pages=False)
mocked_time.sleep.assert_called_once_with(sleep_time)
diff --git a/tests/secrets/test_local_filesystem.py b/tests/secrets/test_local_filesystem.py
index 5ee4813f8285b..93e83abc7d1e8 100644
--- a/tests/secrets/test_local_filesystem.py
+++ b/tests/secrets/test_local_filesystem.py
@@ -22,6 +22,7 @@
from tempfile import NamedTemporaryFile
from unittest import mock
+import pytest
from parameterized import parameterized
from airflow.exceptions import AirflowException, AirflowFileParseException, ConnectionNotUnique
@@ -46,7 +47,7 @@ class FileParsers(unittest.TestCase):
)
def test_env_file_invalid_format(self, content, expected_message):
with mock_local_file(content):
- with self.assertRaisesRegex(AirflowFileParseException, re.escape(expected_message)):
+ with pytest.raises(AirflowFileParseException, match=re.escape(expected_message)):
local_filesystem.load_variables("a.env")
@parameterized.expand(
@@ -58,7 +59,7 @@ def test_env_file_invalid_format(self, content, expected_message):
)
def test_json_file_invalid_format(self, content, expected_message):
with mock_local_file(content):
- with self.assertRaisesRegex(AirflowFileParseException, re.escape(expected_message)):
+ with pytest.raises(AirflowFileParseException, match=re.escape(expected_message)):
local_filesystem.load_variables("a.json")
@@ -75,12 +76,12 @@ class TestLoadVariables(unittest.TestCase):
def test_env_file_should_load_variables(self, file_content, expected_variables):
with mock_local_file(file_content):
variables = local_filesystem.load_variables("a.env")
- self.assertEqual(expected_variables, variables)
+ assert expected_variables == variables
@parameterized.expand((("AA=A\nAA=B", "The \"a.env\" file contains multiple values for keys: ['AA']"),))
def test_env_file_invalid_logic(self, content, expected_message):
with mock_local_file(content):
- with self.assertRaisesRegex(AirflowException, re.escape(expected_message)):
+ with pytest.raises(AirflowException, match=re.escape(expected_message)):
local_filesystem.load_variables("a.env")
@parameterized.expand(
@@ -94,13 +95,13 @@ def test_env_file_invalid_logic(self, content, expected_message):
def test_json_file_should_load_variables(self, file_content, expected_variables):
with mock_local_file(json.dumps(file_content)):
variables = local_filesystem.load_variables("a.json")
- self.assertEqual(expected_variables, variables)
+ assert expected_variables == variables
@mock.patch("airflow.secrets.local_filesystem.os.path.exists", return_value=False)
def test_missing_file(self, mock_exists):
- with self.assertRaisesRegex(
+ with pytest.raises(
AirflowException,
- re.escape("File a.json was not found. Check the configuration of your Secrets backend."),
+ match=re.escape("File a.json was not found. Check the configuration of your Secrets backend."),
):
local_filesystem.load_variables("a.json")
@@ -119,7 +120,7 @@ def test_missing_file(self, mock_exists):
def test_yaml_file_should_load_variables(self, file_content, expected_variables):
with mock_local_file(file_content):
variables = local_filesystem.load_variables('a.yaml')
- self.assertEqual(expected_variables, variables)
+ assert expected_variables == variables
class TestLoadConnection(unittest.TestCase):
@@ -147,7 +148,7 @@ def test_env_file_should_load_connection(self, file_content, expected_connection
conn_id: connection.get_uri() for conn_id, connection in connection_by_conn_id.items()
}
- self.assertEqual(expected_connection_uris, connection_uris_by_conn_id)
+ assert expected_connection_uris == connection_uris_by_conn_id
@parameterized.expand(
(
@@ -157,7 +158,7 @@ def test_env_file_should_load_connection(self, file_content, expected_connection
)
def test_env_file_invalid_format(self, content, expected_message):
with mock_local_file(content):
- with self.assertRaisesRegex(AirflowFileParseException, re.escape(expected_message)):
+ with pytest.raises(AirflowFileParseException, match=re.escape(expected_message)):
local_filesystem.load_connections_dict("a.env")
@parameterized.expand(
@@ -175,7 +176,7 @@ def test_json_file_should_load_connection(self, file_content, expected_connectio
conn_id: connection.get_uri() for conn_id, connection in connections_by_conn_id.items()
}
- self.assertEqual(expected_connection_uris, connection_uris_by_conn_id)
+ assert expected_connection_uris == connection_uris_by_conn_id
@parameterized.expand(
(
@@ -190,14 +191,14 @@ def test_json_file_should_load_connection(self, file_content, expected_connectio
)
def test_env_file_invalid_input(self, file_content, expected_connection_uris):
with mock_local_file(json.dumps(file_content)):
- with self.assertRaisesRegex(AirflowException, re.escape(expected_connection_uris)):
+ with pytest.raises(AirflowException, match=re.escape(expected_connection_uris)):
local_filesystem.load_connections_dict("a.json")
@mock.patch("airflow.secrets.local_filesystem.os.path.exists", return_value=False)
def test_missing_file(self, mock_exists):
- with self.assertRaisesRegex(
+ with pytest.raises(
AirflowException,
- re.escape("File a.json was not found. Check the configuration of your Secrets backend."),
+ match=re.escape("File a.json was not found. Check the configuration of your Secrets backend."),
):
local_filesystem.load_connections_dict("a.json")
@@ -236,7 +237,7 @@ def test_yaml_file_should_load_connection(self, file_content, expected_connectio
conn_id: connection.get_uri() for conn_id, connection in connections_by_conn_id.items()
}
- self.assertEqual(expected_connection_uris, connection_uris_by_conn_id)
+ assert expected_connection_uris == connection_uris_by_conn_id
@parameterized.expand(
(
@@ -297,7 +298,7 @@ def test_yaml_file_should_load_connection_extras(self, file_content, expected_ex
connection_uris_by_conn_id = {
conn_id: connection.extra_dejson for conn_id, connection in connections_by_conn_id.items()
}
- self.assertEqual(expected_extras, connection_uris_by_conn_id)
+ assert expected_extras == connection_uris_by_conn_id
@parameterized.expand(
(
@@ -321,7 +322,7 @@ def test_yaml_file_should_load_connection_extras(self, file_content, expected_ex
)
def test_yaml_invalid_extra(self, file_content, expected_message):
with mock_local_file(file_content):
- with self.assertRaisesRegex(AirflowException, re.escape(expected_message)):
+ with pytest.raises(AirflowException, match=re.escape(expected_message)):
local_filesystem.load_connections_dict("a.yaml")
@parameterized.expand(
@@ -329,7 +330,7 @@ def test_yaml_invalid_extra(self, file_content, expected_message):
)
def test_ensure_unique_connection_env(self, file_content):
with mock_local_file(file_content):
- with self.assertRaises(ConnectionNotUnique):
+ with pytest.raises(ConnectionNotUnique):
local_filesystem.load_connections_dict("a.env")
@parameterized.expand(
@@ -340,7 +341,7 @@ def test_ensure_unique_connection_env(self, file_content):
)
def test_ensure_unique_connection_json(self, file_content):
with mock_local_file(json.dumps(file_content)):
- with self.assertRaises(ConnectionNotUnique):
+ with pytest.raises(ConnectionNotUnique):
local_filesystem.load_connections_dict("a.json")
@parameterized.expand(
@@ -355,7 +356,7 @@ def test_ensure_unique_connection_json(self, file_content):
)
def test_ensure_unique_connection_yaml(self, file_content):
with mock_local_file(file_content):
- with self.assertRaises(ConnectionNotUnique):
+ with pytest.raises(ConnectionNotUnique):
local_filesystem.load_connections_dict("a.yaml")
@@ -365,21 +366,18 @@ def test_should_read_variable(self):
tmp_file.write(b"KEY_A=VAL_A")
tmp_file.flush()
backend = LocalFilesystemBackend(variables_file_path=tmp_file.name)
- self.assertEqual("VAL_A", backend.get_variable("KEY_A"))
- self.assertIsNone(backend.get_variable("KEY_B"))
+ assert "VAL_A" == backend.get_variable("KEY_A")
+ assert backend.get_variable("KEY_B") is None
def test_should_read_connection(self):
with NamedTemporaryFile(suffix=".env") as tmp_file:
tmp_file.write(b"CONN_A=mysql://host_a")
tmp_file.flush()
backend = LocalFilesystemBackend(connections_file_path=tmp_file.name)
- self.assertEqual(
- ["mysql://host_a"],
- [conn.get_uri() for conn in backend.get_connections("CONN_A")],
- )
- self.assertIsNone(backend.get_variable("CONN_B"))
+ assert ["mysql://host_a"] == [conn.get_uri() for conn in backend.get_connections("CONN_A")]
+ assert backend.get_variable("CONN_B") is None
def test_files_are_optional(self):
backend = LocalFilesystemBackend()
- self.assertEqual([], backend.get_connections("CONN_A"))
- self.assertIsNone(backend.get_variable("VAR_A"))
+ assert [] == backend.get_connections("CONN_A")
+ assert backend.get_variable("VAR_A") is None
diff --git a/tests/secrets/test_secrets.py b/tests/secrets/test_secrets.py
index fe14a1bce8b22..31baf3808e6f3 100644
--- a/tests/secrets/test_secrets.py
+++ b/tests/secrets/test_secrets.py
@@ -55,8 +55,8 @@ def test_initialize_secrets_backends(self):
backends = initialize_secrets_backends()
backend_classes = [backend.__class__.__name__ for backend in backends]
- self.assertEqual(3, len(backends))
- self.assertIn('SystemsManagerParameterStoreBackend', backend_classes)
+ assert 3 == len(backends)
+ assert 'SystemsManagerParameterStoreBackend' in backend_classes
@conf_vars(
{
@@ -75,7 +75,7 @@ def test_backends_kwargs(self):
if backend.__class__.__name__ == 'SystemsManagerParameterStoreBackend'
][0]
- self.assertEqual(systems_manager.kwargs, {'use_ssl': False})
+ assert systems_manager.kwargs == {'use_ssl': False}
@conf_vars(
{
@@ -101,14 +101,14 @@ def test_backend_fallback_to_env_var(self, mock_get_uri):
backends = ensure_secrets_loaded()
backend_classes = [backend.__class__.__name__ for backend in backends]
- self.assertIn('SystemsManagerParameterStoreBackend', backend_classes)
+ assert 'SystemsManagerParameterStoreBackend' in backend_classes
conn = Connection.get_connection_from_secrets(conn_id="test_mysql")
# Assert that SystemsManagerParameterStoreBackend.get_conn_uri was called
mock_get_uri.assert_called_once_with(conn_id='test_mysql')
- self.assertEqual('mysql://airflow:airflow@host:5432/airflow', conn.get_uri())
+ assert 'mysql://airflow:airflow@host:5432/airflow' == conn.get_uri()
class TestVariableFromSecrets(unittest.TestCase):
@@ -148,4 +148,4 @@ def test_backend_fallback_to_default_var(self):
the value returned is default_var
"""
variable_value = Variable.get(key="test_var", default_var="new")
- self.assertEqual("new", variable_value)
+ assert "new" == variable_value
diff --git a/tests/secrets/test_secrets_backends.py b/tests/secrets/test_secrets_backends.py
index 5318b8b1212d9..7e2eba88c550d 100644
--- a/tests/secrets/test_secrets_backends.py
+++ b/tests/secrets/test_secrets_backends.py
@@ -56,18 +56,18 @@ def tearDown(self) -> None:
)
def test_build_path(self, _, kwargs, output):
build_path = BaseSecretsBackend.build_path
- self.assertEqual(build_path(**kwargs), output)
+ assert build_path(**kwargs) == output
def test_connection_env_secrets_backend(self):
sample_conn_1 = SampleConn("sample_1", "A")
env_secrets_backend = EnvironmentVariablesBackend()
os.environ[sample_conn_1.var_name] = sample_conn_1.conn_uri
conn_list = env_secrets_backend.get_connections(sample_conn_1.conn_id)
- self.assertEqual(1, len(conn_list))
+ assert 1 == len(conn_list)
conn = conn_list[0]
# we could make this more precise by defining __eq__ method for Connection
- self.assertEqual(sample_conn_1.host.lower(), conn.host)
+ assert sample_conn_1.host.lower() == conn.host
def test_connection_metastore_secrets_backend(self):
sample_conn_2 = SampleConn("sample_2", "A")
@@ -77,7 +77,7 @@ def test_connection_metastore_secrets_backend(self):
metastore_backend = MetastoreBackend()
conn_list = metastore_backend.get_connections("sample_2")
host_list = {x.host for x in conn_list}
- self.assertEqual({sample_conn_2.host.lower()}, set(host_list))
+ assert {sample_conn_2.host.lower()} == set(host_list)
@mock.patch.dict(
'os.environ',
@@ -89,15 +89,15 @@ def test_connection_metastore_secrets_backend(self):
def test_variable_env_secrets_backend(self):
env_secrets_backend = EnvironmentVariablesBackend()
variable_value = env_secrets_backend.get_variable(key="hello")
- self.assertEqual('World', variable_value)
- self.assertIsNone(env_secrets_backend.get_variable(key="non_existent_key"))
- self.assertEqual('', env_secrets_backend.get_variable(key="empty_str"))
+ assert 'World' == variable_value
+ assert env_secrets_backend.get_variable(key="non_existent_key") is None
+ assert '' == env_secrets_backend.get_variable(key="empty_str")
def test_variable_metastore_secrets_backend(self):
Variable.set(key="hello", value="World")
Variable.set(key="empty_str", value="")
metastore_backend = MetastoreBackend()
variable_value = metastore_backend.get_variable(key="hello")
- self.assertEqual("World", variable_value)
- self.assertIsNone(metastore_backend.get_variable(key="non_existent_key"))
- self.assertEqual('', metastore_backend.get_variable(key="empty_str"))
+ assert "World" == variable_value
+ assert metastore_backend.get_variable(key="non_existent_key") is None
+ assert '' == metastore_backend.get_variable(key="empty_str")
diff --git a/tests/security/test_kerberos.py b/tests/security/test_kerberos.py
index 3f08540f38491..674ff16769463 100644
--- a/tests/security/test_kerberos.py
+++ b/tests/security/test_kerberos.py
@@ -20,6 +20,8 @@
import unittest
from argparse import Namespace
+import pytest
+
from airflow.security import kerberos
from airflow.security.kerberos import renew_from_kt
from tests.test_utils.config import conf_vars
@@ -32,16 +34,16 @@ class TestKerberos(unittest.TestCase):
def setUp(self):
self.args = Namespace(
keytab=KRB5_KTNAME, principal=None, pid=None, daemon=None, stdout=None, stderr=None, log_file=None
- )
+ ) # pylint: disable=no-member
@conf_vars({('kerberos', 'keytab'): KRB5_KTNAME})
def test_renew_from_kt(self):
"""
We expect no result, but a successful run. No more TypeError
"""
- self.assertIsNone(
- renew_from_kt(principal=self.args.principal, keytab=self.args.keytab) # pylint: disable=no-member
- )
+ assert (
+ renew_from_kt(principal=self.args.principal, keytab=self.args.keytab) is None
+ ) # pylint: disable=no-member
@conf_vars({('kerberos', 'keytab'): ''})
def test_args_from_cli(self):
@@ -50,15 +52,14 @@ def test_args_from_cli(self):
"""
self.args.keytab = "test_keytab"
- with self.assertRaises(SystemExit) as err:
+ with pytest.raises(SystemExit) as ctx:
renew_from_kt(principal=self.args.principal, keytab=self.args.keytab) # pylint: disable=no-member
with self.assertLogs(kerberos.log) as log:
- self.assertIn(
+ assert (
'kinit: krb5_init_creds_set_keytab: Failed to find '
'airflow@LUPUS.GRIDDYNAMICS.NET in keytab FILE:{} '
- '(unknown enctype)'.format(self.args.keytab),
- log.output,
+ '(unknown enctype)'.format(self.args.keytab) in log.output
)
- self.assertEqual(err.exception.code, 1)
+ assert ctx.value.code == 1
diff --git a/tests/sensors/test_base.py b/tests/sensors/test_base.py
index 61e66f2290f56..d47701183e2a3 100644
--- a/tests/sensors/test_base.py
+++ b/tests/sensors/test_base.py
@@ -21,6 +21,7 @@
from time import sleep
from unittest.mock import Mock, patch
+import pytest
from freezegun import freeze_time
from airflow.exceptions import AirflowException, AirflowRescheduleException, AirflowSensorTimeout
@@ -99,26 +100,26 @@ def test_ok(self):
self._run(sensor)
tis = dr.get_task_instances()
- self.assertEqual(len(tis), 2)
+ assert len(tis) == 2
for ti in tis:
if ti.task_id == SENSOR_OP:
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
if ti.task_id == DUMMY_OP:
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
def test_fail(self):
sensor = self._make_sensor(False)
dr = self._make_dag_run()
- with self.assertRaises(AirflowSensorTimeout):
+ with pytest.raises(AirflowSensorTimeout):
self._run(sensor)
tis = dr.get_task_instances()
- self.assertEqual(len(tis), 2)
+ assert len(tis) == 2
for ti in tis:
if ti.task_id == SENSOR_OP:
- self.assertEqual(ti.state, State.FAILED)
+ assert ti.state == State.FAILED
if ti.task_id == DUMMY_OP:
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
def test_soft_fail(self):
sensor = self._make_sensor(False, soft_fail=True)
@@ -126,12 +127,12 @@ def test_soft_fail(self):
self._run(sensor)
tis = dr.get_task_instances()
- self.assertEqual(len(tis), 2)
+ assert len(tis) == 2
for ti in tis:
if ti.task_id == SENSOR_OP:
- self.assertEqual(ti.state, State.SKIPPED)
+ assert ti.state == State.SKIPPED
if ti.task_id == DUMMY_OP:
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
def test_soft_fail_with_retries(self):
sensor = self._make_sensor(
@@ -140,26 +141,26 @@ def test_soft_fail_with_retries(self):
dr = self._make_dag_run()
# first run fails and task instance is marked up to retry
- with self.assertRaises(AirflowSensorTimeout):
+ with pytest.raises(AirflowSensorTimeout):
self._run(sensor)
tis = dr.get_task_instances()
- self.assertEqual(len(tis), 2)
+ assert len(tis) == 2
for ti in tis:
if ti.task_id == SENSOR_OP:
- self.assertEqual(ti.state, State.UP_FOR_RETRY)
+ assert ti.state == State.UP_FOR_RETRY
if ti.task_id == DUMMY_OP:
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
sleep(0.001)
# after retry DAG run is skipped
self._run(sensor)
tis = dr.get_task_instances()
- self.assertEqual(len(tis), 2)
+ assert len(tis) == 2
for ti in tis:
if ti.task_id == SENSOR_OP:
- self.assertEqual(ti.state, State.SKIPPED)
+ assert ti.state == State.SKIPPED
if ti.task_id == DUMMY_OP:
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
def test_ok_with_reschedule(self):
sensor = self._make_sensor(return_value=None, poke_interval=10, timeout=25, mode='reschedule')
@@ -171,58 +172,54 @@ def test_ok_with_reschedule(self):
with freeze_time(date1):
self._run(sensor)
tis = dr.get_task_instances()
- self.assertEqual(len(tis), 2)
+ assert len(tis) == 2
for ti in tis:
if ti.task_id == SENSOR_OP:
# verify task is re-scheduled, i.e. state set to NONE
- self.assertEqual(ti.state, State.UP_FOR_RESCHEDULE)
+ assert ti.state == State.UP_FOR_RESCHEDULE
# verify task start date is the initial one
- self.assertEqual(ti.start_date, date1)
+ assert ti.start_date == date1
# verify one row in task_reschedule table
task_reschedules = TaskReschedule.find_for_task_instance(ti)
- self.assertEqual(len(task_reschedules), 1)
- self.assertEqual(task_reschedules[0].start_date, date1)
- self.assertEqual(
- task_reschedules[0].reschedule_date, date1 + timedelta(seconds=sensor.poke_interval)
- )
+ assert len(task_reschedules) == 1
+ assert task_reschedules[0].start_date == date1
+ assert task_reschedules[0].reschedule_date == date1 + timedelta(seconds=sensor.poke_interval)
if ti.task_id == DUMMY_OP:
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
# second poke returns False and task is re-scheduled
date2 = date1 + timedelta(seconds=sensor.poke_interval)
with freeze_time(date2):
self._run(sensor)
tis = dr.get_task_instances()
- self.assertEqual(len(tis), 2)
+ assert len(tis) == 2
for ti in tis:
if ti.task_id == SENSOR_OP:
# verify task is re-scheduled, i.e. state set to NONE
- self.assertEqual(ti.state, State.UP_FOR_RESCHEDULE)
+ assert ti.state == State.UP_FOR_RESCHEDULE
# verify task start date is the initial one
- self.assertEqual(ti.start_date, date1)
+ assert ti.start_date == date1
# verify two rows in task_reschedule table
task_reschedules = TaskReschedule.find_for_task_instance(ti)
- self.assertEqual(len(task_reschedules), 2)
- self.assertEqual(task_reschedules[1].start_date, date2)
- self.assertEqual(
- task_reschedules[1].reschedule_date, date2 + timedelta(seconds=sensor.poke_interval)
- )
+ assert len(task_reschedules) == 2
+ assert task_reschedules[1].start_date == date2
+ assert task_reschedules[1].reschedule_date == date2 + timedelta(seconds=sensor.poke_interval)
if ti.task_id == DUMMY_OP:
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
# third poke returns True and task succeeds
date3 = date2 + timedelta(seconds=sensor.poke_interval)
with freeze_time(date3):
self._run(sensor)
tis = dr.get_task_instances()
- self.assertEqual(len(tis), 2)
+ assert len(tis) == 2
for ti in tis:
if ti.task_id == SENSOR_OP:
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
# verify task start date is the initial one
- self.assertEqual(ti.start_date, date1)
+ assert ti.start_date == date1
if ti.task_id == DUMMY_OP:
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
def test_fail_with_reschedule(self):
sensor = self._make_sensor(return_value=False, poke_interval=10, timeout=5, mode='reschedule')
@@ -233,25 +230,25 @@ def test_fail_with_reschedule(self):
with freeze_time(date1):
self._run(sensor)
tis = dr.get_task_instances()
- self.assertEqual(len(tis), 2)
+ assert len(tis) == 2
for ti in tis:
if ti.task_id == SENSOR_OP:
- self.assertEqual(ti.state, State.UP_FOR_RESCHEDULE)
+ assert ti.state == State.UP_FOR_RESCHEDULE
if ti.task_id == DUMMY_OP:
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
# second poke returns False, timeout occurs
date2 = date1 + timedelta(seconds=sensor.poke_interval)
with freeze_time(date2):
- with self.assertRaises(AirflowSensorTimeout):
+ with pytest.raises(AirflowSensorTimeout):
self._run(sensor)
tis = dr.get_task_instances()
- self.assertEqual(len(tis), 2)
+ assert len(tis) == 2
for ti in tis:
if ti.task_id == SENSOR_OP:
- self.assertEqual(ti.state, State.FAILED)
+ assert ti.state == State.FAILED
if ti.task_id == DUMMY_OP:
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
def test_soft_fail_with_reschedule(self):
sensor = self._make_sensor(
@@ -264,24 +261,24 @@ def test_soft_fail_with_reschedule(self):
with freeze_time(date1):
self._run(sensor)
tis = dr.get_task_instances()
- self.assertEqual(len(tis), 2)
+ assert len(tis) == 2
for ti in tis:
if ti.task_id == SENSOR_OP:
- self.assertEqual(ti.state, State.UP_FOR_RESCHEDULE)
+ assert ti.state == State.UP_FOR_RESCHEDULE
if ti.task_id == DUMMY_OP:
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
# second poke returns False, timeout occurs
date2 = date1 + timedelta(seconds=sensor.poke_interval)
with freeze_time(date2):
self._run(sensor)
tis = dr.get_task_instances()
- self.assertEqual(len(tis), 2)
+ assert len(tis) == 2
for ti in tis:
if ti.task_id == SENSOR_OP:
- self.assertEqual(ti.state, State.SKIPPED)
+ assert ti.state == State.SKIPPED
if ti.task_id == DUMMY_OP:
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
def test_ok_with_reschedule_and_retry(self):
sensor = self._make_sensor(
@@ -300,78 +297,74 @@ def test_ok_with_reschedule_and_retry(self):
with freeze_time(date1):
self._run(sensor)
tis = dr.get_task_instances()
- self.assertEqual(len(tis), 2)
+ assert len(tis) == 2
for ti in tis:
if ti.task_id == SENSOR_OP:
- self.assertEqual(ti.state, State.UP_FOR_RESCHEDULE)
+ assert ti.state == State.UP_FOR_RESCHEDULE
# verify one row in task_reschedule table
task_reschedules = TaskReschedule.find_for_task_instance(ti)
- self.assertEqual(len(task_reschedules), 1)
- self.assertEqual(task_reschedules[0].start_date, date1)
- self.assertEqual(
- task_reschedules[0].reschedule_date, date1 + timedelta(seconds=sensor.poke_interval)
- )
- self.assertEqual(task_reschedules[0].try_number, 1)
+ assert len(task_reschedules) == 1
+ assert task_reschedules[0].start_date == date1
+ assert task_reschedules[0].reschedule_date == date1 + timedelta(seconds=sensor.poke_interval)
+ assert task_reschedules[0].try_number == 1
if ti.task_id == DUMMY_OP:
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
# second poke fails and task instance is marked up to retry
date2 = date1 + timedelta(seconds=sensor.poke_interval)
with freeze_time(date2):
- with self.assertRaises(AirflowSensorTimeout):
+ with pytest.raises(AirflowSensorTimeout):
self._run(sensor)
tis = dr.get_task_instances()
- self.assertEqual(len(tis), 2)
+ assert len(tis) == 2
for ti in tis:
if ti.task_id == SENSOR_OP:
- self.assertEqual(ti.state, State.UP_FOR_RETRY)
+ assert ti.state == State.UP_FOR_RETRY
if ti.task_id == DUMMY_OP:
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
# third poke returns False and task is rescheduled again
date3 = date2 + timedelta(seconds=sensor.poke_interval) + sensor.retry_delay
with freeze_time(date3):
self._run(sensor)
tis = dr.get_task_instances()
- self.assertEqual(len(tis), 2)
+ assert len(tis) == 2
for ti in tis:
if ti.task_id == SENSOR_OP:
- self.assertEqual(ti.state, State.UP_FOR_RESCHEDULE)
+ assert ti.state == State.UP_FOR_RESCHEDULE
# verify one row in task_reschedule table
task_reschedules = TaskReschedule.find_for_task_instance(ti)
- self.assertEqual(len(task_reschedules), 1)
- self.assertEqual(task_reschedules[0].start_date, date3)
- self.assertEqual(
- task_reschedules[0].reschedule_date, date3 + timedelta(seconds=sensor.poke_interval)
- )
- self.assertEqual(task_reschedules[0].try_number, 2)
+ assert len(task_reschedules) == 1
+ assert task_reschedules[0].start_date == date3
+ assert task_reschedules[0].reschedule_date == date3 + timedelta(seconds=sensor.poke_interval)
+ assert task_reschedules[0].try_number == 2
if ti.task_id == DUMMY_OP:
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
# fourth poke return True and task succeeds
date4 = date3 + timedelta(seconds=sensor.poke_interval)
with freeze_time(date4):
self._run(sensor)
tis = dr.get_task_instances()
- self.assertEqual(len(tis), 2)
+ assert len(tis) == 2
for ti in tis:
if ti.task_id == SENSOR_OP:
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
if ti.task_id == DUMMY_OP:
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
def test_should_include_ready_to_reschedule_dep_in_reschedule_mode(self):
sensor = self._make_sensor(True, mode='reschedule')
deps = sensor.deps
- self.assertIn(ReadyToRescheduleDep(), deps)
+ assert ReadyToRescheduleDep() in deps
def test_should_not_include_ready_to_reschedule_dep_in_poke_mode(self):
sensor = self._make_sensor(True)
deps = sensor.deps
- self.assertNotIn(ReadyToRescheduleDep(), deps)
+ assert ReadyToRescheduleDep() not in deps
def test_invalid_mode(self):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self._make_sensor(return_value=True, mode='foo')
def test_ok_with_custom_reschedule_exception(self):
@@ -392,46 +385,46 @@ def test_ok_with_custom_reschedule_exception(self):
with freeze_time(date1):
self._run(sensor)
tis = dr.get_task_instances()
- self.assertEqual(len(tis), 2)
+ assert len(tis) == 2
for ti in tis:
if ti.task_id == SENSOR_OP:
# verify task is re-scheduled, i.e. state set to NONE
- self.assertEqual(ti.state, State.UP_FOR_RESCHEDULE)
+ assert ti.state == State.UP_FOR_RESCHEDULE
# verify one row in task_reschedule table
task_reschedules = TaskReschedule.find_for_task_instance(ti)
- self.assertEqual(len(task_reschedules), 1)
- self.assertEqual(task_reschedules[0].start_date, date1)
- self.assertEqual(task_reschedules[0].reschedule_date, date2)
+ assert len(task_reschedules) == 1
+ assert task_reschedules[0].start_date == date1
+ assert task_reschedules[0].reschedule_date == date2
if ti.task_id == DUMMY_OP:
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
# second poke returns False and task is re-scheduled
with freeze_time(date2):
self._run(sensor)
tis = dr.get_task_instances()
- self.assertEqual(len(tis), 2)
+ assert len(tis) == 2
for ti in tis:
if ti.task_id == SENSOR_OP:
# verify task is re-scheduled, i.e. state set to NONE
- self.assertEqual(ti.state, State.UP_FOR_RESCHEDULE)
+ assert ti.state == State.UP_FOR_RESCHEDULE
# verify two rows in task_reschedule table
task_reschedules = TaskReschedule.find_for_task_instance(ti)
- self.assertEqual(len(task_reschedules), 2)
- self.assertEqual(task_reschedules[1].start_date, date2)
- self.assertEqual(task_reschedules[1].reschedule_date, date3)
+ assert len(task_reschedules) == 2
+ assert task_reschedules[1].start_date == date2
+ assert task_reschedules[1].reschedule_date == date3
if ti.task_id == DUMMY_OP:
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
# third poke returns True and task succeeds
with freeze_time(date3):
self._run(sensor)
tis = dr.get_task_instances()
- self.assertEqual(len(tis), 2)
+ assert len(tis) == 2
for ti in tis:
if ti.task_id == SENSOR_OP:
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
if ti.task_id == DUMMY_OP:
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
def test_reschedule_with_test_mode(self):
sensor = self._make_sensor(return_value=None, poke_interval=10, timeout=25, mode='reschedule')
@@ -444,22 +437,22 @@ def test_reschedule_with_test_mode(self):
for date in self.dag.date_range(DEFAULT_DATE, end_date=DEFAULT_DATE):
TaskInstance(sensor, date).run(ignore_ti_state=True, test_mode=True)
tis = dr.get_task_instances()
- self.assertEqual(len(tis), 2)
+ assert len(tis) == 2
for ti in tis:
if ti.task_id == SENSOR_OP:
# in test mode state is not modified
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
# in test mode no reschedule request is recorded
task_reschedules = TaskReschedule.find_for_task_instance(ti)
- self.assertEqual(len(task_reschedules), 0)
+ assert len(task_reschedules) == 0
if ti.task_id == DUMMY_OP:
- self.assertEqual(ti.state, State.NONE)
+ assert ti.state == State.NONE
def test_sensor_with_invalid_poke_interval(self):
negative_poke_interval = -10
non_number_poke_interval = "abcd"
positive_poke_interval = 10
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self._make_sensor(
task_id='test_sensor_task_1',
return_value=None,
@@ -467,7 +460,7 @@ def test_sensor_with_invalid_poke_interval(self):
timeout=25,
)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self._make_sensor(
task_id='test_sensor_task_2',
return_value=None,
@@ -483,12 +476,12 @@ def test_sensor_with_invalid_timeout(self):
negative_timeout = -25
non_number_timeout = "abcd"
positive_timeout = 25
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self._make_sensor(
task_id='test_sensor_task_1', return_value=None, poke_interval=10, timeout=negative_timeout
)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
self._make_sensor(
task_id='test_sensor_task_2', return_value=None, poke_interval=10, timeout=non_number_timeout
)
@@ -505,8 +498,8 @@ def test_sensor_with_exponential_backoff_off(self):
def run_duration():
return (timezone.utcnow - started_at).total_seconds()
- self.assertEqual(sensor._get_next_poke_interval(started_at, run_duration, 1), sensor.poke_interval)
- self.assertEqual(sensor._get_next_poke_interval(started_at, run_duration, 2), sensor.poke_interval)
+ assert sensor._get_next_poke_interval(started_at, run_duration, 1) == sensor.poke_interval
+ assert sensor._get_next_poke_interval(started_at, run_duration, 2) == sensor.poke_interval
def test_sensor_with_exponential_backoff_on(self):
@@ -523,10 +516,10 @@ def run_duration():
interval1 = sensor._get_next_poke_interval(started_at, run_duration, 1)
interval2 = sensor._get_next_poke_interval(started_at, run_duration, 2)
- self.assertTrue(interval1 >= 0)
- self.assertTrue(interval1 <= sensor.poke_interval)
- self.assertTrue(interval2 >= sensor.poke_interval)
- self.assertTrue(interval2 > interval1)
+ assert interval1 >= 0
+ assert interval1 <= sensor.poke_interval
+ assert interval2 >= sensor.poke_interval
+ assert interval2 > interval1
@poke_mode_only
@@ -568,14 +561,14 @@ def test_poke_mode_only_allows_poke_mode(self):
def test_poke_mode_only_bad_class_method(self):
sensor = DummyPokeOnlySensor(task_id='foo', mode='poke', poke_changes_mode=False, dag=self.dag)
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
sensor.change_mode('reschedule')
def test_poke_mode_only_bad_init(self):
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
DummyPokeOnlySensor(task_id='foo', mode='reschedule', poke_changes_mode=False, dag=self.dag)
def test_poke_mode_only_bad_poke(self):
sensor = DummyPokeOnlySensor(task_id='foo', mode='poke', poke_changes_mode=True, dag=self.dag)
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
sensor.poke({})
diff --git a/tests/sensors/test_bash.py b/tests/sensors/test_bash.py
index 934d67cd74ec4..4aad2e04ebab7 100644
--- a/tests/sensors/test_bash.py
+++ b/tests/sensors/test_bash.py
@@ -20,6 +20,8 @@
import datetime
import unittest
+import pytest
+
from airflow.exceptions import AirflowSensorTimeout
from airflow.models.dag import DAG
from airflow.sensors.bash import BashSensor
@@ -51,5 +53,5 @@ def test_false_condition(self):
timeout=2,
dag=self.dag,
)
- with self.assertRaises(AirflowSensorTimeout):
+ with pytest.raises(AirflowSensorTimeout):
op.execute(None)
diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py
index afa30e27f9d5f..55080c9667eb1 100644
--- a/tests/sensors/test_external_task_sensor.py
+++ b/tests/sensors/test_external_task_sensor.py
@@ -59,7 +59,7 @@ def test_external_task_sensor(self):
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_catch_overlap_allowed_failed_state(self):
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
ExternalTaskSensor(
task_id='test_external_task_sensor_check',
external_dag_id=TEST_DAG_ID,
@@ -70,7 +70,7 @@ def test_catch_overlap_allowed_failed_state(self):
)
def test_external_task_sensor_wrong_failed_states(self):
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
ExternalTaskSensor(
task_id='test_external_task_sensor_check',
external_dag_id=TEST_DAG_ID,
@@ -100,11 +100,9 @@ def test_external_task_sensor_failed_states_as_success(self):
failed_states=["success"],
dag=self.dag,
)
- with self.assertRaises(AirflowException) as cm:
+ with pytest.raises(AirflowException) as ctx:
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
- self.assertEqual(
- str(cm.exception), "The external task " "time_sensor_check in DAG " "unit_test_dag failed."
- )
+ assert str(ctx.value) == "The external task " "time_sensor_check in DAG " "unit_test_dag failed."
def test_external_dag_sensor(self):
other_dag = DAG('other_dag', default_args=self.args, end_date=DEFAULT_DATE, schedule_interval='@once')
@@ -128,8 +126,8 @@ def test_templated_sensor(self):
instance = TaskInstance(sensor, DEFAULT_DATE)
instance.render_templates()
- self.assertEqual(sensor.external_dag_id, f"dag_{DEFAULT_DATE.date()}")
- self.assertEqual(sensor.external_task_id, f"task_{DEFAULT_DATE.date()}")
+ assert sensor.external_dag_id == f"dag_{DEFAULT_DATE.date()}"
+ assert sensor.external_task_id == f"task_{DEFAULT_DATE.date()}"
def test_external_task_sensor_fn_multiple_execution_dates(self):
bash_command_code = """
@@ -205,7 +203,7 @@ def test_external_task_sensor_fn_multiple_execution_dates(self):
task_without_failure.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
- with self.assertRaises(AirflowSensorTimeout):
+ with pytest.raises(AirflowSensorTimeout):
task_with_failure.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_external_task_sensor_delta(self):
@@ -243,7 +241,7 @@ def test_external_task_sensor_fn(self):
poke_interval=1,
dag=self.dag,
)
- with self.assertRaises(exceptions.AirflowSensorTimeout):
+ with pytest.raises(exceptions.AirflowSensorTimeout):
op2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_external_task_sensor_fn_multiple_args(self):
@@ -286,7 +284,7 @@ def my_func(dt, ds_nodash, tomorrow_ds_nodash):
def test_external_task_sensor_error_delta_and_fn(self):
self.test_time_sensor()
# Test that providing execution_delta and a function raises an error
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
ExternalTaskSensor(
task_id='test_external_task_sensor_check_delta',
external_dag_id=TEST_DAG_ID,
@@ -298,7 +296,7 @@ def test_external_task_sensor_error_delta_and_fn(self):
)
def test_catch_invalid_allowed_states(self):
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
ExternalTaskSensor(
task_id='test_external_task_sensor_check_1',
external_dag_id=TEST_DAG_ID,
@@ -307,7 +305,7 @@ def test_catch_invalid_allowed_states(self):
dag=self.dag,
)
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
ExternalTaskSensor(
task_id='test_external_task_sensor_check_2',
external_dag_id=TEST_DAG_ID,
@@ -325,7 +323,7 @@ def test_external_task_sensor_waits_for_task_check_existence(self):
dag=self.dag,
)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_external_task_sensor_waits_for_dag_check_existence(self):
@@ -337,13 +335,13 @@ def test_external_task_sensor_waits_for_dag_check_existence(self):
dag=self.dag,
)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
class TestExternalTaskMarker(unittest.TestCase):
def test_serialized_fields(self):
- self.assertTrue({"recursion_depth"}.issubset(ExternalTaskMarker.get_serialized_fields()))
+ assert {"recursion_depth"}.issubset(ExternalTaskMarker.get_serialized_fields())
def test_serialized_external_task_marker(self):
dag = DAG('test_serialized_external_task_marker', start_date=DEFAULT_DATE)
@@ -356,9 +354,9 @@ def test_serialized_external_task_marker(self):
serialized_op = SerializedBaseOperator.serialize_operator(task)
deserialized_op = SerializedBaseOperator.deserialize_operator(serialized_op)
- self.assertEqual(deserialized_op.task_type, 'ExternalTaskMarker')
- self.assertEqual(getattr(deserialized_op, 'external_dag_id'), 'external_task_marker_child')
- self.assertEqual(getattr(deserialized_op, 'external_task_id'), 'child_task1')
+ assert deserialized_op.task_type == 'ExternalTaskMarker'
+ assert getattr(deserialized_op, 'external_dag_id') == 'external_task_marker_child'
+ assert getattr(deserialized_op, 'external_task_id') == 'child_task1'
@pytest.fixture
diff --git a/tests/sensors/test_filesystem.py b/tests/sensors/test_filesystem.py
index 8332f9c553cf7..e197f13b79d6c 100644
--- a/tests/sensors/test_filesystem.py
+++ b/tests/sensors/test_filesystem.py
@@ -22,6 +22,8 @@
import tempfile
import unittest
+import pytest
+
from airflow.exceptions import AirflowSensorTimeout
from airflow.models.dag import DAG
from airflow.sensors.filesystem import FileSensor
@@ -65,7 +67,7 @@ def test_file_in_nonexistent_dir(self):
)
task._hook = self.hook
try:
- with self.assertRaises(AirflowSensorTimeout):
+ with pytest.raises(AirflowSensorTimeout):
task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
finally:
shutil.rmtree(temp_dir)
@@ -82,7 +84,7 @@ def test_empty_dir(self):
)
task._hook = self.hook
try:
- with self.assertRaises(AirflowSensorTimeout):
+ with pytest.raises(AirflowSensorTimeout):
task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
finally:
shutil.rmtree(temp_dir)
@@ -159,6 +161,6 @@ def test_subdirectory_empty(self):
)
task._hook = self.hook
- with self.assertRaises(AirflowSensorTimeout):
+ with pytest.raises(AirflowSensorTimeout):
task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
shutil.rmtree(temp_dir)
diff --git a/tests/sensors/test_python.py b/tests/sensors/test_python.py
index c8eee8428420e..37f0114e99812 100644
--- a/tests/sensors/test_python.py
+++ b/tests/sensors/test_python.py
@@ -20,6 +20,8 @@
from collections import namedtuple
from datetime import date
+import pytest
+
from airflow.exceptions import AirflowSensorTimeout
from airflow.sensors.python import PythonSensor
from airflow.utils.state import State
@@ -43,12 +45,12 @@ def test_python_sensor_false(self):
python_callable=lambda: False,
dag=self.dag,
)
- with self.assertRaises(AirflowSensorTimeout):
+ with pytest.raises(AirflowSensorTimeout):
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_python_sensor_raise(self):
op = PythonSensor(task_id='python_sensor_check_raise', python_callable=lambda: 1 / 0, dag=self.dag)
- with self.assertRaises(ZeroDivisionError):
+ with pytest.raises(ZeroDivisionError):
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_python_callable_arguments_are_templatized(self):
@@ -77,12 +79,12 @@ def test_python_callable_arguments_are_templatized(self):
start_date=DEFAULT_DATE,
state=State.RUNNING,
)
- with self.assertRaises(AirflowSensorTimeout):
+ with pytest.raises(AirflowSensorTimeout):
task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
ds_templated = DEFAULT_DATE.date().isoformat()
# 2 calls: first: at start, second: before timeout
- self.assertEqual(2, len(recorded_calls))
+ assert 2 == len(recorded_calls)
self._assert_calls_equal(
recorded_calls[0],
Call(
@@ -118,11 +120,11 @@ def test_python_callable_keyword_arguments_are_templatized(self):
start_date=DEFAULT_DATE,
state=State.RUNNING,
)
- with self.assertRaises(AirflowSensorTimeout):
+ with pytest.raises(AirflowSensorTimeout):
task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
# 2 calls: first: at start, second: before timeout
- self.assertEqual(2, len(recorded_calls))
+ assert 2 == len(recorded_calls)
self._assert_calls_equal(
recorded_calls[0],
Call(
diff --git a/tests/sensors/test_smart_sensor_operator.py b/tests/sensors/test_smart_sensor_operator.py
index 5b798cff86526..bec165bc1bc88 100644
--- a/tests/sensors/test_smart_sensor_operator.py
+++ b/tests/sensors/test_smart_sensor_operator.py
@@ -162,24 +162,24 @@ def test_load_sensor_works(self):
# Confirm initial state
smart = self._make_smart_operator(0)
smart.flush_cached_sensor_poke_results()
- self.assertEqual(len(smart.cached_dedup_works), 0)
- self.assertEqual(len(smart.cached_sensor_exceptions), 0)
+ assert len(smart.cached_dedup_works) == 0
+ assert len(smart.cached_sensor_exceptions) == 0
si1.run(ignore_all_deps=True)
# Test single sensor
smart._load_sensor_works()
- self.assertEqual(len(smart.sensor_works), 1)
- self.assertEqual(len(smart.cached_dedup_works), 0)
- self.assertEqual(len(smart.cached_sensor_exceptions), 0)
+ assert len(smart.sensor_works) == 1
+ assert len(smart.cached_dedup_works) == 0
+ assert len(smart.cached_sensor_exceptions) == 0
si2.run(ignore_all_deps=True)
si3.run(ignore_all_deps=True)
# Test multiple sensors with duplication
smart._load_sensor_works()
- self.assertEqual(len(smart.sensor_works), 3)
- self.assertEqual(len(smart.cached_dedup_works), 0)
- self.assertEqual(len(smart.cached_sensor_exceptions), 0)
+ assert len(smart.sensor_works) == 3
+ assert len(smart.cached_dedup_works) == 0
+ assert len(smart.cached_sensor_exceptions) == 0
def test_execute_single_task_with_dup(self):
sensor_dr = self._make_sensor_dag_run()
@@ -195,7 +195,7 @@ def test_execute_single_task_with_dup(self):
smart.flush_cached_sensor_poke_results()
smart._load_sensor_works()
- self.assertEqual(len(smart.sensor_works), 3)
+ assert len(smart.sensor_works) == 3
for sensor_work in smart.sensor_works:
_, task_id, _ = sensor_work.ti_key
@@ -203,16 +203,16 @@ def test_execute_single_task_with_dup(self):
smart._execute_sensor_work(sensor_work)
break
- self.assertEqual(len(smart.cached_dedup_works), 1)
+ assert len(smart.cached_dedup_works) == 1
tis = sensor_dr.get_task_instances()
for ti in tis:
if ti.task_id == SENSOR_OP + "1":
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
if ti.task_id == SENSOR_OP + "2":
- self.assertEqual(ti.state, State.SUCCESS)
+ assert ti.state == State.SUCCESS
if ti.task_id == SENSOR_OP + "3":
- self.assertEqual(ti.state, State.SENSING)
+ assert ti.state == State.SENSING
for sensor_work in smart.sensor_works:
_, task_id, _ = sensor_work.ti_key
@@ -220,7 +220,7 @@ def test_execute_single_task_with_dup(self):
smart._execute_sensor_work(sensor_work)
break
- self.assertEqual(len(smart.cached_dedup_works), 1)
+ assert len(smart.cached_dedup_works) == 1
time.sleep(1)
for sensor_work in smart.sensor_works:
@@ -229,13 +229,13 @@ def test_execute_single_task_with_dup(self):
smart._execute_sensor_work(sensor_work)
break
- self.assertEqual(len(smart.cached_dedup_works), 2)
+ assert len(smart.cached_dedup_works) == 2
tis = sensor_dr.get_task_instances()
for ti in tis:
# Timeout=0, the Failed poke lead to task fail
if ti.task_id == SENSOR_OP + "3":
- self.assertEqual(ti.state, State.FAILED)
+ assert ti.state == State.FAILED
def test_smart_operator_timeout(self):
sensor_dr = self._make_sensor_dag_run()
@@ -256,7 +256,7 @@ def test_smart_operator_timeout(self):
sis = sensor_dr.get_task_instances()
for sensor_instance in sis:
if sensor_instance.task_id == SENSOR_OP + "1":
- self.assertEqual(sensor_instance.state, State.SENSING)
+ assert sensor_instance.state == State.SENSING
date2 = date1 + datetime.timedelta(seconds=smart.poke_interval)
with freeze_time(date2):
@@ -269,7 +269,7 @@ def test_smart_operator_timeout(self):
sis = sensor_dr.get_task_instances()
for sensor_instance in sis:
if sensor_instance.task_id == SENSOR_OP + "1":
- self.assertEqual(sensor_instance.state, State.SENSING)
+ assert sensor_instance.state == State.SENSING
date3 = date2 + datetime.timedelta(seconds=smart.poke_interval)
with freeze_time(date3):
@@ -282,12 +282,12 @@ def test_smart_operator_timeout(self):
sis = sensor_dr.get_task_instances()
for sensor_instance in sis:
if sensor_instance.task_id == SENSOR_OP + "1":
- self.assertEqual(sensor_instance.state, State.FAILED)
+ assert sensor_instance.state == State.FAILED
def test_register_in_sensor_service(self):
si1 = self._make_sensor_instance(1, True)
si1.run(ignore_all_deps=True)
- self.assertEqual(si1.state, State.SENSING)
+ assert si1.state == State.SENSING
session = settings.Session()
@@ -300,6 +300,6 @@ def test_register_in_sensor_service(self):
.first()
)
- self.assertIsNotNone(sensor_instance)
- self.assertEqual(sensor_instance.state, State.SENSING)
- self.assertEqual(sensor_instance.operator, "DummySensor")
+ assert sensor_instance is not None
+ assert sensor_instance.state == State.SENSING
+ assert sensor_instance.operator == "DummySensor"
diff --git a/tests/sensors/test_sql_sensor.py b/tests/sensors/test_sql_sensor.py
index 88a185665c5b2..fca387b1748e3 100644
--- a/tests/sensors/test_sql_sensor.py
+++ b/tests/sensors/test_sql_sensor.py
@@ -45,7 +45,7 @@ def test_unsupported_conn_type(self):
dag=self.dag,
)
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
@pytest.mark.backend("mysql")
@@ -98,25 +98,25 @@ def test_sql_sensor_postgres_poke(self, mock_hook):
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records
mock_get_records.return_value = []
- self.assertFalse(op.poke(None))
+ assert not op.poke(None)
mock_get_records.return_value = [[None]]
- self.assertFalse(op.poke(None))
+ assert not op.poke(None)
mock_get_records.return_value = [['None']]
- self.assertTrue(op.poke(None))
+ assert op.poke(None)
mock_get_records.return_value = [[0.0]]
- self.assertFalse(op.poke(None))
+ assert not op.poke(None)
mock_get_records.return_value = [[0]]
- self.assertFalse(op.poke(None))
+ assert not op.poke(None)
mock_get_records.return_value = [['0']]
- self.assertTrue(op.poke(None))
+ assert op.poke(None)
mock_get_records.return_value = [['1']]
- self.assertTrue(op.poke(None))
+ assert op.poke(None)
@mock.patch('airflow.sensors.sql.BaseHook')
def test_sql_sensor_postgres_poke_fail_on_empty(self, mock_hook):
@@ -128,7 +128,8 @@ def test_sql_sensor_postgres_poke_fail_on_empty(self, mock_hook):
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records
mock_get_records.return_value = []
- self.assertRaises(AirflowException, op.poke, None)
+ with pytest.raises(AirflowException):
+ op.poke(None)
@mock.patch('airflow.sensors.sql.BaseHook')
def test_sql_sensor_postgres_poke_success(self, mock_hook):
@@ -140,13 +141,13 @@ def test_sql_sensor_postgres_poke_success(self, mock_hook):
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records
mock_get_records.return_value = []
- self.assertFalse(op.poke(None))
+ assert not op.poke(None)
mock_get_records.return_value = [[1]]
- self.assertTrue(op.poke(None))
+ assert op.poke(None)
mock_get_records.return_value = [['1']]
- self.assertFalse(op.poke(None))
+ assert not op.poke(None)
@mock.patch('airflow.sensors.sql.BaseHook')
def test_sql_sensor_postgres_poke_failure(self, mock_hook):
@@ -158,10 +159,11 @@ def test_sql_sensor_postgres_poke_failure(self, mock_hook):
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records
mock_get_records.return_value = []
- self.assertFalse(op.poke(None))
+ assert not op.poke(None)
mock_get_records.return_value = [[1]]
- self.assertRaises(AirflowException, op.poke, None)
+ with pytest.raises(AirflowException):
+ op.poke(None)
@mock.patch('airflow.sensors.sql.BaseHook')
def test_sql_sensor_postgres_poke_failure_success(self, mock_hook):
@@ -177,13 +179,14 @@ def test_sql_sensor_postgres_poke_failure_success(self, mock_hook):
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records
mock_get_records.return_value = []
- self.assertFalse(op.poke(None))
+ assert not op.poke(None)
mock_get_records.return_value = [[1]]
- self.assertRaises(AirflowException, op.poke, None)
+ with pytest.raises(AirflowException):
+ op.poke(None)
mock_get_records.return_value = [[2]]
- self.assertTrue(op.poke(None))
+ assert op.poke(None)
@mock.patch('airflow.sensors.sql.BaseHook')
def test_sql_sensor_postgres_poke_failure_success_same(self, mock_hook):
@@ -199,10 +202,11 @@ def test_sql_sensor_postgres_poke_failure_success_same(self, mock_hook):
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records
mock_get_records.return_value = []
- self.assertFalse(op.poke(None))
+ assert not op.poke(None)
mock_get_records.return_value = [[1]]
- self.assertRaises(AirflowException, op.poke, None)
+ with pytest.raises(AirflowException):
+ op.poke(None)
@mock.patch('airflow.sensors.sql.BaseHook')
def test_sql_sensor_postgres_poke_invalid_failure(self, mock_hook):
@@ -217,9 +221,9 @@ def test_sql_sensor_postgres_poke_invalid_failure(self, mock_hook):
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records
mock_get_records.return_value = [[1]]
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
op.poke(None)
- self.assertEqual("self.failure is present, but not callable -> [1]", str(e.exception))
+ assert "self.failure is present, but not callable -> [1]" == str(ctx.value)
@mock.patch('airflow.sensors.sql.BaseHook')
def test_sql_sensor_postgres_poke_invalid_success(self, mock_hook):
@@ -234,9 +238,9 @@ def test_sql_sensor_postgres_poke_invalid_success(self, mock_hook):
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records
mock_get_records.return_value = [[1]]
- with self.assertRaises(AirflowException) as e:
+ with pytest.raises(AirflowException) as ctx:
op.poke(None)
- self.assertEqual("self.success is present, but not callable -> [1]", str(e.exception))
+ assert "self.success is present, but not callable -> [1]" == str(ctx.value)
@unittest.skipIf(
'AIRFLOW_RUNALL_TESTS' not in os.environ, "Skipped because AIRFLOW_RUNALL_TESTS is not set"
diff --git a/tests/sensors/test_timeout_sensor.py b/tests/sensors/test_timeout_sensor.py
index 1be520e251d82..63b42f42a2ef9 100644
--- a/tests/sensors/test_timeout_sensor.py
+++ b/tests/sensors/test_timeout_sensor.py
@@ -19,6 +19,8 @@
import unittest
from datetime import timedelta
+import pytest
+
from airflow.exceptions import AirflowSensorTimeout, AirflowSkipException
from airflow.models.dag import DAG
from airflow.sensors.base import BaseSensorOperator
@@ -75,6 +77,5 @@ def test_timeout(self):
params={'time_jump': timedelta(days=2, seconds=1)},
dag=self.dag,
)
- self.assertRaises(
- AirflowSensorTimeout, op.run, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True
- )
+ with pytest.raises(AirflowSensorTimeout):
+ op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
diff --git a/tests/sensors/test_weekday_sensor.py b/tests/sensors/test_weekday_sensor.py
index f632e82b28453..1faf1ceea8a5a 100644
--- a/tests/sensors/test_weekday_sensor.py
+++ b/tests/sensors/test_weekday_sensor.py
@@ -19,6 +19,7 @@
import unittest
+import pytest
from parameterized import parameterized
from airflow.exceptions import AirflowSensorTimeout
@@ -67,7 +68,7 @@ def test_weekday_sensor_true(self, _, week_day):
task_id='weekday_sensor_check_true', week_day=week_day, use_task_execution_day=True, dag=self.dag
)
op.run(start_date=WEEKDAY_DATE, end_date=WEEKDAY_DATE, ignore_ti_state=True)
- self.assertEqual(op.week_day, week_day)
+ assert op.week_day == week_day
def test_weekday_sensor_false(self):
op = DayOfWeekSensor(
@@ -78,12 +79,12 @@ def test_weekday_sensor_false(self):
use_task_execution_day=True,
dag=self.dag,
)
- with self.assertRaises(AirflowSensorTimeout):
+ with pytest.raises(AirflowSensorTimeout):
op.run(start_date=WEEKDAY_DATE, end_date=WEEKDAY_DATE, ignore_ti_state=True)
def test_invalid_weekday_number(self):
invalid_week_day = 'Thsday'
- with self.assertRaisesRegex(AttributeError, f'Invalid Week Day passed: "{invalid_week_day}"'):
+ with pytest.raises(AttributeError, match=f'Invalid Week Day passed: "{invalid_week_day}"'):
DayOfWeekSensor(
task_id='weekday_sensor_invalid_weekday_num',
week_day=invalid_week_day,
@@ -93,9 +94,9 @@ def test_invalid_weekday_number(self):
def test_weekday_sensor_with_invalid_type(self):
invalid_week_day = ['Thsday']
- with self.assertRaisesRegex(
+ with pytest.raises(
TypeError,
- 'Unsupported Type for week_day parameter:'
+ match='Unsupported Type for week_day parameter:'
' {}. It should be one of str, set or '
'Weekday enum type'.format(type(invalid_week_day)),
):
@@ -115,5 +116,5 @@ def test_weekday_sensor_timeout_with_set(self):
use_task_execution_day=True,
dag=self.dag,
)
- with self.assertRaises(AirflowSensorTimeout):
+ with pytest.raises(AirflowSensorTimeout):
op.run(start_date=WEEKDAY_DATE, end_date=WEEKDAY_DATE, ignore_ti_state=True)
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index eba7f131aecdf..1b7524acfe17e 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -273,7 +273,7 @@ def test_serialization(self):
def validate_serialized_dag(self, json_dag, ground_truth_dag):
"""Verify serialized DAGs match the ground truth."""
- self.assertTrue(json_dag['dag']['fileloc'].split('/')[-1] == 'test_dag_serialization.py')
+ assert json_dag['dag']['fileloc'].split('/')[-1] == 'test_dag_serialization.py'
json_dag['dag']['fileloc'] = None
def sorted_serialized_dag(dag_dict: dict):
@@ -309,7 +309,7 @@ def test_deserialization_across_process(self):
if v is None:
break
dag = SerializedDAG.from_json(v)
- self.assertTrue(isinstance(dag, DAG))
+ assert isinstance(dag, DAG)
stringified_dags[dag.dag_id] = dag
dags = collect_dags("airflow/example_dags")
@@ -441,13 +441,13 @@ def test_deserialization_start_date(self, dag_start_date, task_start_date, expec
if not task_start_date or dag_start_date >= task_start_date:
# If dag.start_date > task.start_date -> task.start_date=dag.start_date
# because of the logic in dag.add_task()
- self.assertNotIn("start_date", serialized_dag["dag"]["tasks"][0])
+ assert "start_date" not in serialized_dag["dag"]["tasks"][0]
else:
- self.assertIn("start_date", serialized_dag["dag"]["tasks"][0])
+ assert "start_date" in serialized_dag["dag"]["tasks"][0]
dag = SerializedDAG.from_dict(serialized_dag)
simple_task = dag.task_dict["simple_task"]
- self.assertEqual(simple_task.start_date, expected_task_start_date)
+ assert simple_task.start_date == expected_task_start_date
def test_deserialization_with_dag_context(self):
with DAG(dag_id='simple_dag', start_date=datetime(2019, 8, 1, tzinfo=timezone.utc)) as dag:
@@ -478,13 +478,13 @@ def test_deserialization_end_date(self, dag_end_date, task_end_date, expected_ta
if not task_end_date or dag_end_date <= task_end_date:
# If dag.end_date < task.end_date -> task.end_date=dag.end_date
# because of the logic in dag.add_task()
- self.assertNotIn("end_date", serialized_dag["dag"]["tasks"][0])
+ assert "end_date" not in serialized_dag["dag"]["tasks"][0]
else:
- self.assertIn("end_date", serialized_dag["dag"]["tasks"][0])
+ assert "end_date" in serialized_dag["dag"]["tasks"][0]
dag = SerializedDAG.from_dict(serialized_dag)
simple_task = dag.task_dict["simple_task"]
- self.assertEqual(simple_task.end_date, expected_task_end_date)
+ assert simple_task.end_date == expected_task_end_date
@parameterized.expand(
[
@@ -513,8 +513,8 @@ def test_deserialization_schedule_interval(
dag = SerializedDAG.from_dict(serialized)
- self.assertEqual(dag.schedule_interval, expected_schedule_interval)
- self.assertEqual(dag.normalized_schedule_interval, expected_n_schedule_interval)
+ assert dag.schedule_interval == expected_schedule_interval
+ assert dag.normalized_schedule_interval == expected_n_schedule_interval
@parameterized.expand(
[
@@ -528,10 +528,10 @@ def test_deserialization_schedule_interval(
)
def test_roundtrip_relativedelta(self, val, expected):
serialized = SerializedDAG._serialize(val)
- self.assertDictEqual(serialized, expected)
+ assert serialized == expected
round_tripped = SerializedDAG._deserialize(serialized)
- self.assertEqual(val, round_tripped)
+ assert val == round_tripped
@parameterized.expand(
[
@@ -548,14 +548,14 @@ def test_dag_params_roundtrip(self, val, expected_val):
serialized_dag = SerializedDAG.to_dict(dag)
if val:
- self.assertIn("params", serialized_dag["dag"])
+ assert "params" in serialized_dag["dag"]
else:
- self.assertNotIn("params", serialized_dag["dag"])
+ assert "params" not in serialized_dag["dag"]
deserialized_dag = SerializedDAG.from_dict(serialized_dag)
deserialized_simple_task = deserialized_dag.task_dict["simple_task"]
- self.assertEqual(expected_val, deserialized_dag.params)
- self.assertEqual(expected_val, deserialized_simple_task.params)
+ assert expected_val == deserialized_dag.params
+ assert expected_val == deserialized_simple_task.params
@parameterized.expand(
[
@@ -572,13 +572,13 @@ def test_task_params_roundtrip(self, val, expected_val):
serialized_dag = SerializedDAG.to_dict(dag)
if val:
- self.assertIn("params", serialized_dag["dag"]["tasks"][0])
+ assert "params" in serialized_dag["dag"]["tasks"][0]
else:
- self.assertNotIn("params", serialized_dag["dag"]["tasks"][0])
+ assert "params" not in serialized_dag["dag"]["tasks"][0]
deserialized_dag = SerializedDAG.from_dict(serialized_dag)
deserialized_simple_task = deserialized_dag.task_dict["simple_task"]
- self.assertEqual(expected_val, deserialized_simple_task.params)
+ assert expected_val == deserialized_simple_task.params
def test_extra_serialized_field_and_operator_links(self):
"""
@@ -597,34 +597,33 @@ def test_extra_serialized_field_and_operator_links(self):
CustomOperator(task_id='simple_task', dag=dag, bash_command="true")
serialized_dag = SerializedDAG.to_dict(dag)
- self.assertIn("bash_command", serialized_dag["dag"]["tasks"][0])
+ assert "bash_command" in serialized_dag["dag"]["tasks"][0]
dag = SerializedDAG.from_dict(serialized_dag)
simple_task = dag.task_dict["simple_task"]
- self.assertEqual(getattr(simple_task, "bash_command"), "true")
+ assert getattr(simple_task, "bash_command") == "true"
#########################################################
# Verify Operator Links work with Serialized Operator
#########################################################
# Check Serialized version of operator link only contains the inbuilt Op Link
- self.assertEqual(
- serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
- [{'tests.test_utils.mock_operators.CustomOpLink': {}}],
- )
+ assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
+ {'tests.test_utils.mock_operators.CustomOpLink': {}}
+ ]
# Test all the extra_links are set
- self.assertCountEqual(simple_task.extra_links, ['Google Custom', 'airflow', 'github', 'google'])
+ assert set(simple_task.extra_links) == {'Google Custom', 'airflow', 'github', 'google'}
ti = TaskInstance(task=simple_task, execution_date=test_date)
ti.xcom_push('search_query', "dummy_value_1")
# Test Deserialized inbuilt link
custom_inbuilt_link = simple_task.get_extra_links(test_date, CustomOpLink.name)
- self.assertEqual('http://google.com/custom_base_link?search=dummy_value_1', custom_inbuilt_link)
+ assert 'http://google.com/custom_base_link?search=dummy_value_1' == custom_inbuilt_link
# Test Deserialized link registered via Airflow Plugin
google_link_from_plugin = simple_task.get_extra_links(test_date, GoogleLink.name)
- self.assertEqual("https://www.google.com", google_link_from_plugin)
+ assert "https://www.google.com" == google_link_from_plugin
def test_extra_operator_links_logs_error_for_non_registered_extra_links(self):
"""
@@ -679,44 +678,44 @@ def test_extra_serialized_field_and_multiple_operator_links(self):
CustomOperator(task_id='simple_task', dag=dag, bash_command=["echo", "true"])
serialized_dag = SerializedDAG.to_dict(dag)
- self.assertIn("bash_command", serialized_dag["dag"]["tasks"][0])
+ assert "bash_command" in serialized_dag["dag"]["tasks"][0]
dag = SerializedDAG.from_dict(serialized_dag)
simple_task = dag.task_dict["simple_task"]
- self.assertEqual(getattr(simple_task, "bash_command"), ["echo", "true"])
+ assert getattr(simple_task, "bash_command") == ["echo", "true"]
#########################################################
# Verify Operator Links work with Serialized Operator
#########################################################
# Check Serialized version of operator link only contains the inbuilt Op Link
- self.assertEqual(
- serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
- [
- {'tests.test_utils.mock_operators.CustomBaseIndexOpLink': {'index': 0}},
- {'tests.test_utils.mock_operators.CustomBaseIndexOpLink': {'index': 1}},
- ],
- )
+ assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
+ {'tests.test_utils.mock_operators.CustomBaseIndexOpLink': {'index': 0}},
+ {'tests.test_utils.mock_operators.CustomBaseIndexOpLink': {'index': 1}},
+ ]
# Test all the extra_links are set
- self.assertCountEqual(
- simple_task.extra_links,
- ['BigQuery Console #1', 'BigQuery Console #2', 'airflow', 'github', 'google'],
- )
+ assert set(simple_task.extra_links) == {
+ 'BigQuery Console #1',
+ 'BigQuery Console #2',
+ 'airflow',
+ 'github',
+ 'google',
+ }
ti = TaskInstance(task=simple_task, execution_date=test_date)
ti.xcom_push('search_query', ["dummy_value_1", "dummy_value_2"])
# Test Deserialized inbuilt link #1
custom_inbuilt_link = simple_task.get_extra_links(test_date, "BigQuery Console #1")
- self.assertEqual('https://console.cloud.google.com/bigquery?j=dummy_value_1', custom_inbuilt_link)
+ assert 'https://console.cloud.google.com/bigquery?j=dummy_value_1' == custom_inbuilt_link
# Test Deserialized inbuilt link #2
custom_inbuilt_link = simple_task.get_extra_links(test_date, "BigQuery Console #2")
- self.assertEqual('https://console.cloud.google.com/bigquery?j=dummy_value_2', custom_inbuilt_link)
+ assert 'https://console.cloud.google.com/bigquery?j=dummy_value_2' == custom_inbuilt_link
# Test Deserialized link registered via Airflow Plugin
google_link_from_plugin = simple_task.get_extra_links(test_date, GoogleLink.name)
- self.assertEqual("https://www.google.com", google_link_from_plugin)
+ assert "https://www.google.com" == google_link_from_plugin
class ClassWithCustomAttributes:
"""
@@ -796,7 +795,7 @@ def test_templated_fields_exist_in_serialized_dag(self, templated_field, expecte
serialized_dag = SerializedDAG.to_dict(dag)
deserialized_dag = SerializedDAG.from_dict(serialized_dag)
deserialized_test_task = deserialized_dag.task_dict["test"]
- self.assertEqual(expected_field, getattr(deserialized_test_task, "bash_command"))
+ assert expected_field == getattr(deserialized_test_task, "bash_command")
def test_dag_serialized_fields_with_schema(self):
"""
@@ -808,7 +807,7 @@ def test_dag_serialized_fields_with_schema(self):
# The parameters we add manually in Serialization needs to be ignored
ignored_keys: set = {"is_subdag", "tasks", "has_on_success_callback", "has_on_failure_callback"}
dag_params: set = set(dag_schema.keys()) - ignored_keys
- self.assertEqual(set(DAG.get_serialized_fields()), dag_params)
+ assert set(DAG.get_serialized_fields()) == dag_params
def test_operator_subclass_changing_base_defaults(self):
assert (
@@ -835,53 +834,50 @@ def test_no_new_fields_added_to_base_operator(self):
"""
base_operator = BaseOperator(task_id="10")
fields = base_operator.__dict__
- self.assertEqual(
- {
- '_BaseOperator__instantiated': True,
- '_dag': None,
- '_downstream_task_ids': set(),
- '_inlets': [],
- '_log': base_operator.log,
- '_outlets': [],
- '_upstream_task_ids': set(),
- 'depends_on_past': False,
- 'do_xcom_push': True,
- 'email': None,
- 'email_on_failure': True,
- 'email_on_retry': True,
- 'end_date': None,
- 'execution_timeout': None,
- 'executor_config': {},
- 'inlets': [],
- 'label': '10',
- 'max_retry_delay': None,
- 'on_execute_callback': None,
- 'on_failure_callback': None,
- 'on_retry_callback': None,
- 'on_success_callback': None,
- 'outlets': [],
- 'owner': 'airflow',
- 'params': {},
- 'pool': 'default_pool',
- 'pool_slots': 1,
- 'priority_weight': 1,
- 'queue': 'default',
- 'resources': None,
- 'retries': 0,
- 'retry_delay': timedelta(0, 300),
- 'retry_exponential_backoff': False,
- 'run_as_user': None,
- 'sla': None,
- 'start_date': None,
- 'subdag': None,
- 'task_concurrency': None,
- 'task_id': '10',
- 'trigger_rule': 'all_success',
- 'wait_for_downstream': False,
- 'weight_rule': 'downstream',
- },
- fields,
- """
+ assert {
+ '_BaseOperator__instantiated': True,
+ '_dag': None,
+ '_downstream_task_ids': set(),
+ '_inlets': [],
+ '_log': base_operator.log,
+ '_outlets': [],
+ '_upstream_task_ids': set(),
+ 'depends_on_past': False,
+ 'do_xcom_push': True,
+ 'email': None,
+ 'email_on_failure': True,
+ 'email_on_retry': True,
+ 'end_date': None,
+ 'execution_timeout': None,
+ 'executor_config': {},
+ 'inlets': [],
+ 'label': '10',
+ 'max_retry_delay': None,
+ 'on_execute_callback': None,
+ 'on_failure_callback': None,
+ 'on_retry_callback': None,
+ 'on_success_callback': None,
+ 'outlets': [],
+ 'owner': 'airflow',
+ 'params': {},
+ 'pool': 'default_pool',
+ 'pool_slots': 1,
+ 'priority_weight': 1,
+ 'queue': 'default',
+ 'resources': None,
+ 'retries': 0,
+ 'retry_delay': timedelta(0, 300),
+ 'retry_exponential_backoff': False,
+ 'run_as_user': None,
+ 'sla': None,
+ 'start_date': None,
+ 'subdag': None,
+ 'task_concurrency': None,
+ 'task_id': '10',
+ 'trigger_rule': 'all_success',
+ 'wait_for_downstream': False,
+ 'weight_rule': 'downstream',
+ } == fields, """
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
ACTION NEEDED! PLEASE READ THIS CAREFULLY AND CORRECT TESTS CAREFULLY
@@ -893,8 +889,7 @@ def test_no_new_fields_added_to_base_operator(self):
Note that we do not support versioning yet so you should only add optional fields to BaseOperator.
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
- """,
- )
+ """
def test_task_group_serialization(self):
"""
diff --git a/tests/task/task_runner/test_cgroup_task_runner.py b/tests/task/task_runner/test_cgroup_task_runner.py
index c5fb97098abe0..61f525f38fab0 100644
--- a/tests/task/task_runner/test_cgroup_task_runner.py
+++ b/tests/task/task_runner/test_cgroup_task_runner.py
@@ -37,7 +37,7 @@ def test_cgroup_task_runner_super_calls(self, mock_super_on_finish, mock_super_i
local_task_job.task_instance.command_as_list.return_value = ['sleep', '1000']
runner = CgroupTaskRunner(local_task_job)
- self.assertTrue(mock_super_init.called)
+ assert mock_super_init.called
runner.on_finish()
- self.assertTrue(mock_super_on_finish.called)
+ assert mock_super_on_finish.called
diff --git a/tests/task/task_runner/test_standard_task_runner.py b/tests/task/task_runner/test_standard_task_runner.py
index a442d05e03eaa..fcd4948477289 100644
--- a/tests/task/task_runner/test_standard_task_runner.py
+++ b/tests/task/task_runner/test_standard_task_runner.py
@@ -88,17 +88,17 @@ def test_start_and_terminate(self):
time.sleep(0.5)
pgid = os.getpgid(runner.process.pid)
- self.assertGreater(pgid, 0)
- self.assertNotEqual(pgid, os.getpgid(0), "Task should be in a different process group to us")
+ assert pgid > 0
+ assert pgid != os.getpgid(0), "Task should be in a different process group to us"
processes = list(self._procs_in_pgroup(pgid))
runner.terminate()
for process in processes:
- self.assertFalse(psutil.pid_exists(process.pid), f"{process} is still alive")
+ assert not psutil.pid_exists(process.pid), f"{process} is still alive"
- self.assertIsNotNone(runner.return_code())
+ assert runner.return_code() is not None
def test_start_and_terminate_run_as_user(self):
local_task_job = mock.Mock()
@@ -119,17 +119,17 @@ def test_start_and_terminate_run_as_user(self):
time.sleep(0.5)
pgid = os.getpgid(runner.process.pid)
- self.assertGreater(pgid, 0)
- self.assertNotEqual(pgid, os.getpgid(0), "Task should be in a different process group to us")
+ assert pgid > 0
+ assert pgid != os.getpgid(0), "Task should be in a different process group to us"
processes = list(self._procs_in_pgroup(pgid))
runner.terminate()
for process in processes:
- self.assertFalse(psutil.pid_exists(process.pid), f"{process} is still alive")
+ assert not psutil.pid_exists(process.pid), f"{process} is still alive"
- self.assertIsNotNone(runner.return_code())
+ assert runner.return_code() is not None
def test_on_kill(self):
"""
@@ -170,8 +170,8 @@ def test_on_kill(self):
time.sleep(3)
pgid = os.getpgid(runner.process.pid)
- self.assertGreater(pgid, 0)
- self.assertNotEqual(pgid, os.getpgid(0), "Task should be in a different process group to us")
+ assert pgid > 0
+ assert pgid != os.getpgid(0), "Task should be in a different process group to us"
processes = list(self._procs_in_pgroup(pgid))
@@ -184,10 +184,10 @@ def test_on_kill(self):
time.sleep(2)
with open(path) as f:
- self.assertEqual("ON_KILL_TEST", f.readline())
+ assert "ON_KILL_TEST" == f.readline()
for process in processes:
- self.assertFalse(psutil.pid_exists(process.pid), f"{process} is still alive")
+ assert not psutil.pid_exists(process.pid), f"{process} is still alive"
@staticmethod
def _procs_in_pgroup(pgid):
diff --git a/tests/task/task_runner/test_task_runner.py b/tests/task/task_runner/test_task_runner.py
index 4601333ee18ea..b5fd73f42b29e 100644
--- a/tests/task/task_runner/test_task_runner.py
+++ b/tests/task/task_runner/test_task_runner.py
@@ -29,7 +29,7 @@
class GetTaskRunner(unittest.TestCase):
@parameterized.expand([(import_path,) for import_path in CORE_TASK_RUNNERS.values()])
def test_should_have_valid_imports(self, import_path):
- self.assertIsNotNone(import_string(import_path))
+ assert import_string(import_path) is not None
@mock.patch('airflow.task.task_runner.base_task_runner.subprocess')
@mock.patch('airflow.task.task_runner._TASK_RUNNER_NAME', "StandardTaskRunner")
@@ -39,7 +39,7 @@ def test_should_support_core_task_runner(self, mock_subprocess):
)
task_runner = get_task_runner(local_task_job)
- self.assertEqual("StandardTaskRunner", task_runner.__class__.__name__)
+ assert "StandardTaskRunner" == task_runner.__class__.__name__
@mock.patch(
'airflow.task.task_runner._TASK_RUNNER_NAME',
@@ -55,4 +55,4 @@ def test_should_support_custom_task_runner(self):
custom_task_runner.assert_called_once_with(local_task_job)
- self.assertEqual(custom_task_runner.return_value, task_runner)
+ assert custom_task_runner.return_value == task_runner
diff --git a/tests/test_utils/perf/perf_kit/__init__.py b/tests/test_utils/perf/perf_kit/__init__.py
index 0b4e344b343d1..25583fcfcf347 100644
--- a/tests/test_utils/perf/perf_kit/__init__.py
+++ b/tests/test_utils/perf/perf_kit/__init__.py
@@ -69,8 +69,8 @@
prev = dag.previous_schedule(_next)
prev_local = local_tz.convert(prev)
- self.assertEqual(prev_local.isoformat(), "2018-03-24T03:00:00+01:00")
- self.assertEqual(prev.isoformat(), "2018-03-24T02:00:00+00:00")
+ assert prev_local.isoformat() == "2018-03-24T03:00:00+01:00"
+ assert prev.isoformat() == "2018-03-24T02:00:00+00:00"
def test_bulk_write_to_db(self):
clear_db_dags()
@@ -90,8 +90,8 @@ def test_bulk_write_to_db(self):
prev = dag.previous_schedule(_next)
prev_local = local_tz.convert(prev)
- self.assertEqual(prev_local.isoformat(), "2018-03-24T03:00:00+01:00")
- self.assertEqual(prev.isoformat(), "2018-03-24T02:00:00+00:00")
+ assert prev_local.isoformat() == "2018-03-24T03:00:00+01:00"
+ assert prev.isoformat() == "2018-03-24T02:00:00+00:00"
from tests.utils.perf.perf_kit.sqlalchemy import trace_queries
diff --git a/tests/test_utils/reset_warning_registry.py b/tests/test_utils/reset_warning_registry.py
index 42a9981c12211..0b3b2f2a18934 100644
--- a/tests/test_utils/reset_warning_registry.py
+++ b/tests/test_utils/reset_warning_registry.py
@@ -58,7 +58,7 @@ def __enter__(self):
for name, mod in list(sys.modules.items()):
if pattern.match(name):
reg = getattr(mod, "__warningregistry__", None)
- if reg:
+ if reg and isinstance(reg, dict):
backup[name] = reg.copy()
reg.clear()
return self
@@ -83,5 +83,5 @@ def __exit__(self, *exc_info):
for name, mod in list(modules.items()):
if pattern.match(name) and name not in backup:
reg = getattr(mod, "__warningregistry__", None)
- if reg:
+ if reg and isinstance(reg, dict):
reg.clear()
diff --git a/tests/test_utils/test_remote_user_api_auth_backend.py b/tests/test_utils/test_remote_user_api_auth_backend.py
index 8083ed0611dc4..04da30adfac3b 100644
--- a/tests/test_utils/test_remote_user_api_auth_backend.py
+++ b/tests/test_utils/test_remote_user_api_auth_backend.py
@@ -52,10 +52,10 @@ def test_success_using_username(self):
with self.app.test_client() as test_client:
response = test_client.get("/api/experimental/pools", environ_overrides={'REMOTE_USER': "test"})
- self.assertEqual("test@fab.org", current_user.email)
+ assert "test@fab.org" == current_user.email
- self.assertEqual(200, response.status_code)
- self.assertIn("Default pool", str(response.json))
+ assert 200 == response.status_code
+ assert "Default pool" in str(response.json)
def test_success_using_email(self):
clear_db_pools()
@@ -64,10 +64,10 @@ def test_success_using_email(self):
response = test_client.get(
"/api/experimental/pools", environ_overrides={'REMOTE_USER': "test@fab.org"}
)
- self.assertEqual("test@fab.org", current_user.email)
+ assert "test@fab.org" == current_user.email
- self.assertEqual(200, response.status_code)
- self.assertIn("Default pool", str(response.json))
+ assert 200 == response.status_code
+ assert "Default pool" in str(response.json)
def test_user_not_exists(self):
with self.app.test_client() as test_client:
@@ -75,15 +75,15 @@ def test_user_not_exists(self):
"/api/experimental/pools", environ_overrides={'REMOTE_USER': "INVALID"}
)
- self.assertEqual(403, response.status_code)
- self.assertEqual("Forbidden", response.data.decode())
+ assert 403 == response.status_code
+ assert "Forbidden" == response.data.decode()
def test_missing_remote_user(self):
with self.app.test_client() as test_client:
response = test_client.get("/api/experimental/pools")
- self.assertEqual(403, response.status_code)
- self.assertEqual("Forbidden", response.data.decode())
+ assert 403 == response.status_code
+ assert "Forbidden" == response.data.decode()
def test_user_should_be_logged_temporary(self):
with self.app.test_client() as test_client:
@@ -91,5 +91,5 @@ def test_user_should_be_logged_temporary(self):
"/api/experimental/pools", environ_overrides={'REMOTE_USER': "INVALID"}
)
- self.assertEqual(403, response.status_code)
- self.assertEqual("Forbidden", response.data.decode())
+ assert 403 == response.status_code
+ assert "Forbidden" == response.data.decode()
diff --git a/tests/ti_deps/deps/test_dag_ti_slots_available_dep.py b/tests/ti_deps/deps/test_dag_ti_slots_available_dep.py
index f6fba610f6710..e1a570e23ebc9 100644
--- a/tests/ti_deps/deps/test_dag_ti_slots_available_dep.py
+++ b/tests/ti_deps/deps/test_dag_ti_slots_available_dep.py
@@ -32,7 +32,7 @@ def test_concurrency_reached(self):
task = Mock(dag=dag, pool_slots=1)
ti = TaskInstance(task, execution_date=None)
- self.assertFalse(DagTISlotsAvailableDep().is_met(ti=ti))
+ assert not DagTISlotsAvailableDep().is_met(ti=ti)
def test_all_conditions_met(self):
"""
@@ -42,4 +42,4 @@ def test_all_conditions_met(self):
task = Mock(dag=dag, pool_slots=1)
ti = TaskInstance(task, execution_date=None)
- self.assertTrue(DagTISlotsAvailableDep().is_met(ti=ti))
+ assert DagTISlotsAvailableDep().is_met(ti=ti)
diff --git a/tests/ti_deps/deps/test_dag_unpaused_dep.py b/tests/ti_deps/deps/test_dag_unpaused_dep.py
index 1303f616c9370..6fbfed70d6176 100644
--- a/tests/ti_deps/deps/test_dag_unpaused_dep.py
+++ b/tests/ti_deps/deps/test_dag_unpaused_dep.py
@@ -32,7 +32,7 @@ def test_concurrency_reached(self):
task = Mock(dag=dag)
ti = TaskInstance(task=task, execution_date=None)
- self.assertFalse(DagUnpausedDep().is_met(ti=ti))
+ assert not DagUnpausedDep().is_met(ti=ti)
def test_all_conditions_met(self):
"""
@@ -42,4 +42,4 @@ def test_all_conditions_met(self):
task = Mock(dag=dag)
ti = TaskInstance(task=task, execution_date=None)
- self.assertTrue(DagUnpausedDep().is_met(ti=ti))
+ assert DagUnpausedDep().is_met(ti=ti)
diff --git a/tests/ti_deps/deps/test_dagrun_exists_dep.py b/tests/ti_deps/deps/test_dagrun_exists_dep.py
index a1696c0794810..65179f5774805 100644
--- a/tests/ti_deps/deps/test_dagrun_exists_dep.py
+++ b/tests/ti_deps/deps/test_dagrun_exists_dep.py
@@ -32,7 +32,7 @@ def test_dagrun_doesnt_exist(self, mock_dagrun_find):
"""
dag = DAG('test_dag', max_active_runs=2)
ti = Mock(task=Mock(dag=dag), get_dagrun=Mock(return_value=None))
- self.assertFalse(DagrunRunningDep().is_met(ti=ti))
+ assert not DagrunRunningDep().is_met(ti=ti)
def test_dagrun_exists(self):
"""
@@ -40,4 +40,4 @@ def test_dagrun_exists(self):
"""
dagrun = DagRun(state=State.RUNNING)
ti = Mock(get_dagrun=Mock(return_value=dagrun))
- self.assertTrue(DagrunRunningDep().is_met(ti=ti))
+ assert DagrunRunningDep().is_met(ti=ti)
diff --git a/tests/ti_deps/deps/test_dagrun_id_dep.py b/tests/ti_deps/deps/test_dagrun_id_dep.py
index 04fdc462ee6ca..c05d9c4a3fce2 100644
--- a/tests/ti_deps/deps/test_dagrun_id_dep.py
+++ b/tests/ti_deps/deps/test_dagrun_id_dep.py
@@ -33,7 +33,7 @@ def test_dagrun_id_is_backfill(self):
dagrun.run_id = "anything"
dagrun.run_type = DagRunType.BACKFILL_JOB
ti = Mock(get_dagrun=Mock(return_value=dagrun))
- self.assertFalse(DagrunIdDep().is_met(ti=ti))
+ assert not DagrunIdDep().is_met(ti=ti)
def test_dagrun_id_is_not_backfill(self):
"""
@@ -42,16 +42,16 @@ def test_dagrun_id_is_not_backfill(self):
dagrun = DagRun()
dagrun.run_type = 'custom_type'
ti = Mock(get_dagrun=Mock(return_value=dagrun))
- self.assertTrue(DagrunIdDep().is_met(ti=ti))
+ assert DagrunIdDep().is_met(ti=ti)
dagrun = DagRun()
dagrun.run_id = None
ti = Mock(get_dagrun=Mock(return_value=dagrun))
- self.assertTrue(DagrunIdDep().is_met(ti=ti))
+ assert DagrunIdDep().is_met(ti=ti)
def test_dagrun_is_none(self):
"""
Task instances which don't yet have an associated dagrun.
"""
ti = Mock(get_dagrun=Mock(return_value=None))
- self.assertTrue(DagrunIdDep().is_met(ti=ti))
+ assert DagrunIdDep().is_met(ti=ti)
diff --git a/tests/ti_deps/deps/test_not_in_retry_period_dep.py b/tests/ti_deps/deps/test_not_in_retry_period_dep.py
index 2fd28d32e9675..7e393918f1cd7 100644
--- a/tests/ti_deps/deps/test_not_in_retry_period_dep.py
+++ b/tests/ti_deps/deps/test_not_in_retry_period_dep.py
@@ -41,8 +41,8 @@ def test_still_in_retry_period(self):
Task instances that are in their retry period should fail this dep
"""
ti = self._get_task_instance(State.UP_FOR_RETRY, end_date=datetime(2016, 1, 1, 15, 30))
- self.assertTrue(ti.is_premature)
- self.assertFalse(NotInRetryPeriodDep().is_met(ti=ti))
+ assert ti.is_premature
+ assert not NotInRetryPeriodDep().is_met(ti=ti)
@freeze_time('2016-01-01 15:46')
def test_retry_period_finished(self):
@@ -50,12 +50,12 @@ def test_retry_period_finished(self):
Task instance's that have had their retry period elapse should pass this dep
"""
ti = self._get_task_instance(State.UP_FOR_RETRY, end_date=datetime(2016, 1, 1))
- self.assertFalse(ti.is_premature)
- self.assertTrue(NotInRetryPeriodDep().is_met(ti=ti))
+ assert not ti.is_premature
+ assert NotInRetryPeriodDep().is_met(ti=ti)
def test_not_in_retry_period(self):
"""
Task instance's that are not up for retry can not be in their retry period
"""
ti = self._get_task_instance(State.SUCCESS)
- self.assertTrue(NotInRetryPeriodDep().is_met(ti=ti))
+ assert NotInRetryPeriodDep().is_met(ti=ti)
diff --git a/tests/ti_deps/deps/test_pool_slots_available_dep.py b/tests/ti_deps/deps/test_pool_slots_available_dep.py
index 927e315c74740..bcc22951f74de 100644
--- a/tests/ti_deps/deps/test_pool_slots_available_dep.py
+++ b/tests/ti_deps/deps/test_pool_slots_available_dep.py
@@ -40,21 +40,21 @@ def tearDown(self):
# pylint: disable=unused-argument
def test_pooled_task_reached_concurrency(self, mock_open_slots):
ti = Mock(pool='test_pool', pool_slots=1)
- self.assertFalse(PoolSlotsAvailableDep().is_met(ti=ti))
+ assert not PoolSlotsAvailableDep().is_met(ti=ti)
@patch('airflow.models.Pool.open_slots', return_value=1)
# pylint: disable=unused-argument
def test_pooled_task_pass(self, mock_open_slots):
ti = Mock(pool='test_pool', pool_slots=1)
- self.assertTrue(PoolSlotsAvailableDep().is_met(ti=ti))
+ assert PoolSlotsAvailableDep().is_met(ti=ti)
@patch('airflow.models.Pool.open_slots', return_value=0)
# pylint: disable=unused-argument
def test_running_pooled_task_pass(self, mock_open_slots):
for state in EXECUTION_STATES:
ti = Mock(pool='test_pool', state=state, pool_slots=1)
- self.assertTrue(PoolSlotsAvailableDep().is_met(ti=ti))
+ assert PoolSlotsAvailableDep().is_met(ti=ti)
def test_task_with_nonexistent_pool(self):
ti = Mock(pool='nonexistent_pool', pool_slots=1)
- self.assertFalse(PoolSlotsAvailableDep().is_met(ti=ti))
+ assert not PoolSlotsAvailableDep().is_met(ti=ti)
diff --git a/tests/ti_deps/deps/test_prev_dagrun_dep.py b/tests/ti_deps/deps/test_prev_dagrun_dep.py
index 4a05d77ea72c1..29137f53f0b13 100644
--- a/tests/ti_deps/deps/test_prev_dagrun_dep.py
+++ b/tests/ti_deps/deps/test_prev_dagrun_dep.py
@@ -48,7 +48,7 @@ def test_not_depends_on_past(self):
ti = Mock(task=task, previous_ti=prev_ti, execution_date=datetime(2016, 1, 3))
dep_context = DepContext(ignore_depends_on_past=False)
- self.assertTrue(PrevDagrunDep().is_met(ti=ti, dep_context=dep_context))
+ assert PrevDagrunDep().is_met(ti=ti, dep_context=dep_context)
def test_context_ignore_depends_on_past(self):
"""
@@ -67,7 +67,7 @@ def test_context_ignore_depends_on_past(self):
ti = Mock(task=task, previous_ti=prev_ti, execution_date=datetime(2016, 1, 3))
dep_context = DepContext(ignore_depends_on_past=True)
- self.assertTrue(PrevDagrunDep().is_met(ti=ti, dep_context=dep_context))
+ assert PrevDagrunDep().is_met(ti=ti, dep_context=dep_context)
def test_first_task_run(self):
"""
@@ -80,7 +80,7 @@ def test_first_task_run(self):
ti = Mock(task=task, previous_ti=prev_ti, execution_date=datetime(2016, 1, 1))
dep_context = DepContext(ignore_depends_on_past=False)
- self.assertTrue(PrevDagrunDep().is_met(ti=ti, dep_context=dep_context))
+ assert PrevDagrunDep().is_met(ti=ti, dep_context=dep_context)
def test_prev_ti_bad_state(self):
"""
@@ -93,7 +93,7 @@ def test_prev_ti_bad_state(self):
ti = Mock(task=task, previous_ti=prev_ti, execution_date=datetime(2016, 1, 2))
dep_context = DepContext(ignore_depends_on_past=False)
- self.assertFalse(PrevDagrunDep().is_met(ti=ti, dep_context=dep_context))
+ assert not PrevDagrunDep().is_met(ti=ti, dep_context=dep_context)
def test_failed_wait_for_downstream(self):
"""
@@ -106,7 +106,7 @@ def test_failed_wait_for_downstream(self):
ti = Mock(task=task, previous_ti=prev_ti, execution_date=datetime(2016, 1, 2))
dep_context = DepContext(ignore_depends_on_past=False)
- self.assertFalse(PrevDagrunDep().is_met(ti=ti, dep_context=dep_context))
+ assert not PrevDagrunDep().is_met(ti=ti, dep_context=dep_context)
def test_all_met(self):
"""
@@ -117,4 +117,4 @@ def test_all_met(self):
ti = Mock(task=task, execution_date=datetime(2016, 1, 2), **{'get_previous_ti.return_value': prev_ti})
dep_context = DepContext(ignore_depends_on_past=False)
- self.assertTrue(PrevDagrunDep().is_met(ti=ti, dep_context=dep_context))
+ assert PrevDagrunDep().is_met(ti=ti, dep_context=dep_context)
diff --git a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py
index f9bbb9f282b3f..bcef997b87e68 100644
--- a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py
+++ b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py
@@ -49,17 +49,17 @@ def _get_task_reschedule(self, reschedule_date):
def test_should_pass_if_ignore_in_reschedule_period_is_set(self):
ti = self._get_task_instance(State.UP_FOR_RESCHEDULE)
dep_context = DepContext(ignore_in_reschedule_period=True)
- self.assertTrue(ReadyToRescheduleDep().is_met(ti=ti, dep_context=dep_context))
+ assert ReadyToRescheduleDep().is_met(ti=ti, dep_context=dep_context)
def test_should_pass_if_not_in_none_state(self):
ti = self._get_task_instance(State.UP_FOR_RETRY)
- self.assertTrue(ReadyToRescheduleDep().is_met(ti=ti))
+ assert ReadyToRescheduleDep().is_met(ti=ti)
@patch('airflow.models.taskreschedule.TaskReschedule.query_for_task_instance')
def test_should_pass_if_no_reschedule_record_exists(self, mock_query_for_task_instance):
mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = []
ti = self._get_task_instance(State.NONE)
- self.assertTrue(ReadyToRescheduleDep().is_met(ti=ti))
+ assert ReadyToRescheduleDep().is_met(ti=ti)
@patch('airflow.models.taskreschedule.TaskReschedule.query_for_task_instance')
def test_should_pass_after_reschedule_date_one(self, mock_query_for_task_instance):
@@ -67,7 +67,7 @@ def test_should_pass_after_reschedule_date_one(self, mock_query_for_task_instanc
self._get_task_reschedule(utcnow() - timedelta(minutes=1))
)
ti = self._get_task_instance(State.UP_FOR_RESCHEDULE)
- self.assertTrue(ReadyToRescheduleDep().is_met(ti=ti))
+ assert ReadyToRescheduleDep().is_met(ti=ti)
@patch('airflow.models.taskreschedule.TaskReschedule.query_for_task_instance')
def test_should_pass_after_reschedule_date_multiple(self, mock_query_for_task_instance):
@@ -77,7 +77,7 @@ def test_should_pass_after_reschedule_date_multiple(self, mock_query_for_task_in
self._get_task_reschedule(utcnow() - timedelta(minutes=1)),
][-1]
ti = self._get_task_instance(State.UP_FOR_RESCHEDULE)
- self.assertTrue(ReadyToRescheduleDep().is_met(ti=ti))
+ assert ReadyToRescheduleDep().is_met(ti=ti)
@patch('airflow.models.taskreschedule.TaskReschedule.query_for_task_instance')
def test_should_fail_before_reschedule_date_one(self, mock_query_for_task_instance):
@@ -86,7 +86,7 @@ def test_should_fail_before_reschedule_date_one(self, mock_query_for_task_instan
)
ti = self._get_task_instance(State.UP_FOR_RESCHEDULE)
- self.assertFalse(ReadyToRescheduleDep().is_met(ti=ti))
+ assert not ReadyToRescheduleDep().is_met(ti=ti)
@patch('airflow.models.taskreschedule.TaskReschedule.query_for_task_instance')
def test_should_fail_before_reschedule_date_multiple(self, mock_query_for_task_instance):
@@ -96,4 +96,4 @@ def test_should_fail_before_reschedule_date_multiple(self, mock_query_for_task_i
self._get_task_reschedule(utcnow() + timedelta(minutes=1)),
][-1]
ti = self._get_task_instance(State.UP_FOR_RESCHEDULE)
- self.assertFalse(ReadyToRescheduleDep().is_met(ti=ti))
+ assert not ReadyToRescheduleDep().is_met(ti=ti)
diff --git a/tests/ti_deps/deps/test_runnable_exec_date_dep.py b/tests/ti_deps/deps/test_runnable_exec_date_dep.py
index d285c8fbb4794..fb5354e8d92b6 100644
--- a/tests/ti_deps/deps/test_runnable_exec_date_dep.py
+++ b/tests/ti_deps/deps/test_runnable_exec_date_dep.py
@@ -84,7 +84,7 @@ def test_exec_date_after_end_date(self):
op1 = DummyOperator(task_id='op1')
ti = TaskInstance(task=op1, execution_date=datetime(2016, 11, 2))
- self.assertFalse(RunnableExecDateDep().is_met(ti=ti))
+ assert not RunnableExecDateDep().is_met(ti=ti)
def test_exec_date_after_task_end_date(self):
"""
@@ -96,7 +96,7 @@ def test_exec_date_after_task_end_date(self):
task_end_date=datetime(2016, 1, 1),
execution_date=datetime(2016, 1, 2),
)
- self.assertFalse(RunnableExecDateDep().is_met(ti=ti))
+ assert not RunnableExecDateDep().is_met(ti=ti)
def test_exec_date_after_dag_end_date(self):
"""
@@ -108,7 +108,7 @@ def test_exec_date_after_dag_end_date(self):
task_end_date=datetime(2016, 1, 3),
execution_date=datetime(2016, 1, 2),
)
- self.assertFalse(RunnableExecDateDep().is_met(ti=ti))
+ assert not RunnableExecDateDep().is_met(ti=ti)
def test_all_deps_met(self):
"""
@@ -119,4 +119,4 @@ def test_all_deps_met(self):
task_end_date=datetime(2016, 1, 2),
execution_date=datetime(2016, 1, 1),
)
- self.assertTrue(RunnableExecDateDep().is_met(ti=ti))
+ assert RunnableExecDateDep().is_met(ti=ti)
diff --git a/tests/ti_deps/deps/test_task_concurrency.py b/tests/ti_deps/deps/test_task_concurrency.py
index 9bdd1fab1c971..d91f9f656a45b 100644
--- a/tests/ti_deps/deps/test_task_concurrency.py
+++ b/tests/ti_deps/deps/test_task_concurrency.py
@@ -34,20 +34,20 @@ def test_not_task_concurrency(self):
task = self._get_task(start_date=datetime(2016, 1, 1))
dep_context = DepContext()
ti = Mock(task=task, execution_date=datetime(2016, 1, 1))
- self.assertTrue(TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context))
+ assert TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context)
def test_not_reached_concurrency(self):
task = self._get_task(start_date=datetime(2016, 1, 1), task_concurrency=1)
dep_context = DepContext()
ti = Mock(task=task, execution_date=datetime(2016, 1, 1))
ti.get_num_running_task_instances = lambda x: 0
- self.assertTrue(TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context))
+ assert TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context)
def test_reached_concurrency(self):
task = self._get_task(start_date=datetime(2016, 1, 1), task_concurrency=2)
dep_context = DepContext()
ti = Mock(task=task, execution_date=datetime(2016, 1, 1))
ti.get_num_running_task_instances = lambda x: 1
- self.assertTrue(TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context))
+ assert TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context)
ti.get_num_running_task_instances = lambda x: 2
- self.assertFalse(TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context))
+ assert not TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context)
diff --git a/tests/ti_deps/deps/test_task_not_running_dep.py b/tests/ti_deps/deps/test_task_not_running_dep.py
index 353db3fafb8db..b3e558e2fac4c 100644
--- a/tests/ti_deps/deps/test_task_not_running_dep.py
+++ b/tests/ti_deps/deps/test_task_not_running_dep.py
@@ -27,8 +27,8 @@
class TestTaskNotRunningDep(unittest.TestCase):
def test_not_running_state(self):
ti = Mock(state=State.QUEUED, end_date=datetime(2016, 1, 1))
- self.assertTrue(TaskNotRunningDep().is_met(ti=ti))
+ assert TaskNotRunningDep().is_met(ti=ti)
def test_running_state(self):
ti = Mock(state=State.RUNNING, end_date=datetime(2016, 1, 1))
- self.assertFalse(TaskNotRunningDep().is_met(ti=ti))
+ assert not TaskNotRunningDep().is_met(ti=ti)
diff --git a/tests/ti_deps/deps/test_trigger_rule_dep.py b/tests/ti_deps/deps/test_trigger_rule_dep.py
index a83863b306516..8bebaf79d2488 100644
--- a/tests/ti_deps/deps/test_trigger_rule_dep.py
+++ b/tests/ti_deps/deps/test_trigger_rule_dep.py
@@ -45,14 +45,14 @@ def test_no_upstream_tasks(self):
If the TI has no upstream TIs then there is nothing to check and the dep is passed
"""
ti = self._get_task_instance(TriggerRule.ALL_DONE, State.UP_FOR_RETRY)
- self.assertTrue(TriggerRuleDep().is_met(ti=ti))
+ assert TriggerRuleDep().is_met(ti=ti)
def test_dummy_tr(self):
"""
The dummy trigger rule should always pass this dep
"""
ti = self._get_task_instance(TriggerRule.DUMMY, State.UP_FOR_RETRY)
- self.assertTrue(TriggerRuleDep().is_met(ti=ti))
+ assert TriggerRuleDep().is_met(ti=ti)
def test_one_success_tr_success(self):
"""
@@ -71,7 +71,7 @@ def test_one_success_tr_success(self):
session="Fake Session",
)
)
- self.assertEqual(len(dep_statuses), 0)
+ assert len(dep_statuses) == 0
def test_one_success_tr_failure(self):
"""
@@ -90,8 +90,8 @@ def test_one_success_tr_failure(self):
session="Fake Session",
)
)
- self.assertEqual(len(dep_statuses), 1)
- self.assertFalse(dep_statuses[0].passed)
+ assert len(dep_statuses) == 1
+ assert not dep_statuses[0].passed
def test_one_failure_tr_failure(self):
"""
@@ -110,8 +110,8 @@ def test_one_failure_tr_failure(self):
session="Fake Session",
)
)
- self.assertEqual(len(dep_statuses), 1)
- self.assertFalse(dep_statuses[0].passed)
+ assert len(dep_statuses) == 1
+ assert not dep_statuses[0].passed
def test_one_failure_tr_success(self):
"""
@@ -130,7 +130,7 @@ def test_one_failure_tr_success(self):
session="Fake Session",
)
)
- self.assertEqual(len(dep_statuses), 0)
+ assert len(dep_statuses) == 0
dep_statuses = tuple(
TriggerRuleDep()._evaluate_trigger_rule(
@@ -144,7 +144,7 @@ def test_one_failure_tr_success(self):
session="Fake Session",
)
)
- self.assertEqual(len(dep_statuses), 0)
+ assert len(dep_statuses) == 0
def test_all_success_tr_success(self):
"""
@@ -163,7 +163,7 @@ def test_all_success_tr_success(self):
session="Fake Session",
)
)
- self.assertEqual(len(dep_statuses), 0)
+ assert len(dep_statuses) == 0
def test_all_success_tr_failure(self):
"""
@@ -184,8 +184,8 @@ def test_all_success_tr_failure(self):
session="Fake Session",
)
)
- self.assertEqual(len(dep_statuses), 1)
- self.assertFalse(dep_statuses[0].passed)
+ assert len(dep_statuses) == 1
+ assert not dep_statuses[0].passed
def test_all_success_tr_skip(self):
"""
@@ -206,8 +206,8 @@ def test_all_success_tr_skip(self):
session="Fake Session",
)
)
- self.assertEqual(len(dep_statuses), 1)
- self.assertFalse(dep_statuses[0].passed)
+ assert len(dep_statuses) == 1
+ assert not dep_statuses[0].passed
def test_all_success_tr_skip_flag_upstream(self):
"""
@@ -229,9 +229,9 @@ def test_all_success_tr_skip_flag_upstream(self):
session=Mock(),
)
)
- self.assertEqual(len(dep_statuses), 1)
- self.assertFalse(dep_statuses[0].passed)
- self.assertEqual(ti.state, State.SKIPPED)
+ assert len(dep_statuses) == 1
+ assert not dep_statuses[0].passed
+ assert ti.state == State.SKIPPED
def test_none_failed_tr_success(self):
"""
@@ -252,7 +252,7 @@ def test_none_failed_tr_success(self):
session="Fake Session",
)
)
- self.assertEqual(len(dep_statuses), 0)
+ assert len(dep_statuses) == 0
def test_none_failed_tr_skipped(self):
"""
@@ -273,8 +273,8 @@ def test_none_failed_tr_skipped(self):
session=Mock(),
)
)
- self.assertEqual(len(dep_statuses), 0)
- self.assertEqual(ti.state, State.NONE)
+ assert len(dep_statuses) == 0
+ assert ti.state == State.NONE
def test_none_failed_tr_failure(self):
"""
@@ -295,8 +295,8 @@ def test_none_failed_tr_failure(self):
session="Fake Session",
)
)
- self.assertEqual(len(dep_statuses), 1)
- self.assertFalse(dep_statuses[0].passed)
+ assert len(dep_statuses) == 1
+ assert not dep_statuses[0].passed
def test_none_failed_or_skipped_tr_success(self):
"""
@@ -317,7 +317,7 @@ def test_none_failed_or_skipped_tr_success(self):
session="Fake Session",
)
)
- self.assertEqual(len(dep_statuses), 0)
+ assert len(dep_statuses) == 0
def test_none_failed_or_skipped_tr_skipped(self):
"""
@@ -338,8 +338,8 @@ def test_none_failed_or_skipped_tr_skipped(self):
session=Mock(),
)
)
- self.assertEqual(len(dep_statuses), 0)
- self.assertEqual(ti.state, State.SKIPPED)
+ assert len(dep_statuses) == 0
+ assert ti.state == State.SKIPPED
def test_none_failed_or_skipped_tr_failure(self):
"""
@@ -361,8 +361,8 @@ def test_none_failed_or_skipped_tr_failure(self):
session="Fake Session",
)
)
- self.assertEqual(len(dep_statuses), 1)
- self.assertFalse(dep_statuses[0].passed)
+ assert len(dep_statuses) == 1
+ assert not dep_statuses[0].passed
def test_all_failed_tr_success(self):
"""
@@ -383,7 +383,7 @@ def test_all_failed_tr_success(self):
session="Fake Session",
)
)
- self.assertEqual(len(dep_statuses), 0)
+ assert len(dep_statuses) == 0
def test_all_failed_tr_failure(self):
"""
@@ -404,8 +404,8 @@ def test_all_failed_tr_failure(self):
session="Fake Session",
)
)
- self.assertEqual(len(dep_statuses), 1)
- self.assertFalse(dep_statuses[0].passed)
+ assert len(dep_statuses) == 1
+ assert not dep_statuses[0].passed
def test_all_done_tr_success(self):
"""
@@ -426,7 +426,7 @@ def test_all_done_tr_success(self):
session="Fake Session",
)
)
- self.assertEqual(len(dep_statuses), 0)
+ assert len(dep_statuses) == 0
def test_all_done_tr_failure(self):
"""
@@ -447,8 +447,8 @@ def test_all_done_tr_failure(self):
session="Fake Session",
)
)
- self.assertEqual(len(dep_statuses), 1)
- self.assertFalse(dep_statuses[0].passed)
+ assert len(dep_statuses) == 1
+ assert not dep_statuses[0].passed
def test_none_skipped_tr_success(self):
"""
@@ -471,7 +471,7 @@ def test_none_skipped_tr_success(self):
session=session,
)
)
- self.assertEqual(len(dep_statuses), 0)
+ assert len(dep_statuses) == 0
# with `flag_upstream_failed` set to True
dep_statuses = tuple(
@@ -486,7 +486,7 @@ def test_none_skipped_tr_success(self):
session=session,
)
)
- self.assertEqual(len(dep_statuses), 0)
+ assert len(dep_statuses) == 0
def test_none_skipped_tr_failure(self):
"""
@@ -509,8 +509,8 @@ def test_none_skipped_tr_failure(self):
session=session,
)
)
- self.assertEqual(len(dep_statuses), 1)
- self.assertFalse(dep_statuses[0].passed)
+ assert len(dep_statuses) == 1
+ assert not dep_statuses[0].passed
# with `flag_upstream_failed` set to True
dep_statuses = tuple(
@@ -525,8 +525,8 @@ def test_none_skipped_tr_failure(self):
session=session,
)
)
- self.assertEqual(len(dep_statuses), 1)
- self.assertFalse(dep_statuses[0].passed)
+ assert len(dep_statuses) == 1
+ assert not dep_statuses[0].passed
# Fail until all upstream tasks have completed execution
dep_statuses = tuple(
@@ -541,8 +541,8 @@ def test_none_skipped_tr_failure(self):
session=session,
)
)
- self.assertEqual(len(dep_statuses), 1)
- self.assertFalse(dep_statuses[0].passed)
+ assert len(dep_statuses) == 1
+ assert not dep_statuses[0].passed
def test_unknown_tr(self):
"""
@@ -563,8 +563,8 @@ def test_unknown_tr(self):
)
)
- self.assertEqual(len(dep_statuses), 1)
- self.assertFalse(dep_statuses[0].passed)
+ assert len(dep_statuses) == 1
+ assert not dep_statuses[0].passed
def test_get_states_count_upstream_ti(self):
"""
@@ -610,16 +610,10 @@ def test_get_states_count_upstream_ti(self):
# check handling with cases that tasks are triggered from backfill with no finished tasks
finished_tasks = DepContext().ensure_finished_tasks(ti_op2.task.dag, ti_op2.execution_date, session)
- self.assertEqual(
- get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op2), (1, 0, 0, 0, 1)
- )
+ assert get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op2) == (1, 0, 0, 0, 1)
finished_tasks = dr.get_task_instances(state=State.finished, session=session)
- self.assertEqual(
- get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op4), (1, 0, 1, 0, 2)
- )
- self.assertEqual(
- get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op5), (2, 0, 1, 0, 3)
- )
+ assert get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op4) == (1, 0, 1, 0, 2)
+ assert get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op5) == (2, 0, 1, 0, 3)
dr.update_state()
- self.assertEqual(State.SUCCESS, dr.state)
+ assert State.SUCCESS == dr.state
diff --git a/tests/ti_deps/deps/test_valid_state_dep.py b/tests/ti_deps/deps/test_valid_state_dep.py
index 7e6ee7fcf2fac..42118365ba6bb 100644
--- a/tests/ti_deps/deps/test_valid_state_dep.py
+++ b/tests/ti_deps/deps/test_valid_state_dep.py
@@ -20,6 +20,8 @@
from datetime import datetime
from unittest.mock import Mock
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.ti_deps.deps.valid_state_dep import ValidStateDep
from airflow.utils.state import State
@@ -31,19 +33,19 @@ def test_valid_state(self):
Valid state should pass this dep
"""
ti = Mock(state=State.QUEUED, end_date=datetime(2016, 1, 1))
- self.assertTrue(ValidStateDep({State.QUEUED}).is_met(ti=ti))
+ assert ValidStateDep({State.QUEUED}).is_met(ti=ti)
def test_invalid_state(self):
"""
Invalid state should fail this dep
"""
ti = Mock(state=State.SUCCESS, end_date=datetime(2016, 1, 1))
- self.assertFalse(ValidStateDep({State.FAILED}).is_met(ti=ti))
+ assert not ValidStateDep({State.FAILED}).is_met(ti=ti)
def test_no_valid_states(self):
"""
If there are no valid states the dependency should throw
"""
ti = Mock(state=State.SUCCESS, end_date=datetime(2016, 1, 1))
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
ValidStateDep({}).is_met(ti=ti)
diff --git a/tests/utils/log/test_file_processor_handler.py b/tests/utils/log/test_file_processor_handler.py
index 20115059bc7a8..438339c8cf7f8 100644
--- a/tests/utils/log/test_file_processor_handler.py
+++ b/tests/utils/log/test_file_processor_handler.py
@@ -41,11 +41,11 @@ def test_non_template(self):
handler.dag_dir = self.dag_dir
path = os.path.join(self.base_log_folder, "latest")
- self.assertTrue(os.path.islink(path))
- self.assertEqual(os.path.basename(os.readlink(path)), date)
+ assert os.path.islink(path)
+ assert os.path.basename(os.readlink(path)) == date
handler.set_context(filename=os.path.join(self.dag_dir, "logfile"))
- self.assertTrue(os.path.exists(os.path.join(path, "logfile")))
+ assert os.path.exists(os.path.join(path, "logfile"))
def test_template(self):
date = timezone.utcnow().strftime("%Y-%m-%d")
@@ -55,11 +55,11 @@ def test_template(self):
handler.dag_dir = self.dag_dir
path = os.path.join(self.base_log_folder, "latest")
- self.assertTrue(os.path.islink(path))
- self.assertEqual(os.path.basename(os.readlink(path)), date)
+ assert os.path.islink(path)
+ assert os.path.basename(os.readlink(path)) == date
handler.set_context(filename=os.path.join(self.dag_dir, "logfile"))
- self.assertTrue(os.path.exists(os.path.join(path, "logfile.log")))
+ assert os.path.exists(os.path.join(path, "logfile.log"))
def test_symlink_latest_log_directory(self):
handler = FileProcessorHandler(base_log_folder=self.base_log_folder, filename_template=self.filename)
@@ -80,15 +80,15 @@ def test_symlink_latest_log_directory(self):
with freeze_time(date1):
handler.set_context(filename=os.path.join(self.dag_dir, "log1"))
- self.assertTrue(os.path.islink(link))
- self.assertEqual(os.path.basename(os.readlink(link)), date1)
- self.assertTrue(os.path.exists(os.path.join(link, "log1")))
+ assert os.path.islink(link)
+ assert os.path.basename(os.readlink(link)) == date1
+ assert os.path.exists(os.path.join(link, "log1"))
with freeze_time(date2):
handler.set_context(filename=os.path.join(self.dag_dir, "log2"))
- self.assertTrue(os.path.islink(link))
- self.assertEqual(os.path.basename(os.readlink(link)), date2)
- self.assertTrue(os.path.exists(os.path.join(link, "log2")))
+ assert os.path.islink(link)
+ assert os.path.basename(os.readlink(link)) == date2
+ assert os.path.exists(os.path.join(link, "log2"))
def test_symlink_latest_log_directory_exists(self):
handler = FileProcessorHandler(base_log_folder=self.base_log_folder, filename_template=self.filename)
diff --git a/tests/utils/log/test_json_formatter.py b/tests/utils/log/test_json_formatter.py
index ba827fb25e87e..b25d11b1a29f4 100644
--- a/tests/utils/log/test_json_formatter.py
+++ b/tests/utils/log/test_json_formatter.py
@@ -36,7 +36,7 @@ def test_json_formatter_is_not_none(self):
JSONFormatter instance should return not none
"""
json_fmt = JSONFormatter()
- self.assertIsNotNone(json_fmt)
+ assert json_fmt is not None
def test_uses_time(self):
"""
@@ -44,8 +44,8 @@ def test_uses_time(self):
"""
json_fmt_asctime = JSONFormatter(json_fields=["asctime", "label"])
json_fmt_no_asctime = JSONFormatter(json_fields=["label"])
- self.assertTrue(json_fmt_asctime.usesTime())
- self.assertFalse(json_fmt_no_asctime.usesTime())
+ assert json_fmt_asctime.usesTime()
+ assert not json_fmt_no_asctime.usesTime()
def test_format(self):
"""
@@ -53,7 +53,7 @@ def test_format(self):
"""
log_record = makeLogRecord({"label": "value"})
json_fmt = JSONFormatter(json_fields=["label"])
- self.assertEqual(json_fmt.format(log_record), '{"label": "value"}')
+ assert json_fmt.format(log_record) == '{"label": "value"}'
def test_format_with_extras(self):
"""
@@ -62,6 +62,4 @@ def test_format_with_extras(self):
log_record = makeLogRecord({"label": "value"})
json_fmt = JSONFormatter(json_fields=["label"], extras={'pod_extra': 'useful_message'})
# compare as a dicts to not fail on sorting errors
- self.assertDictEqual(
- json.loads(json_fmt.format(log_record)), {"label": "value", "pod_extra": "useful_message"}
- )
+ assert json.loads(json_fmt.format(log_record)) == {"label": "value", "pod_extra": "useful_message"}
diff --git a/tests/utils/log/test_log_reader.py b/tests/utils/log/test_log_reader.py
index 940b13fe6b838..b25ff674e3754 100644
--- a/tests/utils/log/test_log_reader.py
+++ b/tests/utils/log/test_log_reader.py
@@ -103,88 +103,76 @@ def test_test_read_log_chunks_should_read_one_try(self):
task_log_reader = TaskLogReader()
logs, metadatas = task_log_reader.read_log_chunks(ti=self.ti, try_number=1, metadata={})
- self.assertEqual(
+ assert [
+ (
+ '',
+ f"*** Reading local file: "
+ f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/1.log\n"
+ f"try_number=1.\n",
+ )
+ ] == logs[0]
+ assert {"end_of_log": True} == metadatas
+
+ def test_test_read_log_chunks_should_read_all_files(self):
+ task_log_reader = TaskLogReader()
+ logs, metadatas = task_log_reader.read_log_chunks(ti=self.ti, try_number=None, metadata={})
+
+ assert [
[
(
'',
- f"*** Reading local file: "
+ "*** Reading local file: "
f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/1.log\n"
- f"try_number=1.\n",
+ "try_number=1.\n",
+ )
+ ],
+ [
+ (
+ '',
+ f"*** Reading local file: "
+ f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/2.log\n"
+ f"try_number=2.\n",
)
],
- logs[0],
- )
- self.assertEqual({"end_of_log": True}, metadatas)
-
- def test_test_read_log_chunks_should_read_all_files(self):
- task_log_reader = TaskLogReader()
- logs, metadatas = task_log_reader.read_log_chunks(ti=self.ti, try_number=None, metadata={})
-
- self.assertEqual(
[
- [
- (
- '',
- "*** Reading local file: "
- f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/1.log\n"
- "try_number=1.\n",
- )
- ],
- [
- (
- '',
- f"*** Reading local file: "
- f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/2.log\n"
- f"try_number=2.\n",
- )
- ],
- [
- (
- '',
- f"*** Reading local file: "
- f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/3.log\n"
- f"try_number=3.\n",
- )
- ],
+ (
+ '',
+ f"*** Reading local file: "
+ f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/3.log\n"
+ f"try_number=3.\n",
+ )
],
- logs,
- )
- self.assertEqual({"end_of_log": True}, metadatas)
+ ] == logs
+ assert {"end_of_log": True} == metadatas
def test_test_test_read_log_stream_should_read_one_try(self):
task_log_reader = TaskLogReader()
stream = task_log_reader.read_log_stream(ti=self.ti, try_number=1, metadata={})
- self.assertEqual(
- [
- "\n*** Reading local file: "
- f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/1.log\n"
- "try_number=1.\n"
- "\n"
- ],
- list(stream),
- )
+ assert [
+ "\n*** Reading local file: "
+ f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/1.log\n"
+ "try_number=1.\n"
+ "\n"
+ ] == list(stream)
def test_test_test_read_log_stream_should_read_all_logs(self):
task_log_reader = TaskLogReader()
stream = task_log_reader.read_log_stream(ti=self.ti, try_number=None, metadata={})
- self.assertEqual(
- [
- "\n*** Reading local file: "
- f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/1.log\n"
- "try_number=1.\n"
- "\n",
- "\n*** Reading local file: "
- f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/2.log\n"
- "try_number=2.\n"
- "\n",
- "\n*** Reading local file: "
- f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/3.log\n"
- "try_number=3.\n"
- "\n",
- ],
- list(stream),
- )
+ assert [
+ "\n*** Reading local file: "
+ f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/1.log\n"
+ "try_number=1.\n"
+ "\n",
+ "\n*** Reading local file: "
+ f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/2.log\n"
+ "try_number=2.\n"
+ "\n",
+ "\n*** Reading local file: "
+ f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/3.log\n"
+ "try_number=3.\n"
+ "\n",
+ ] == list(stream)
@mock.patch("airflow.utils.log.file_task_handler.FileTaskHandler.read")
def test_read_log_stream_should_support_multiple_chunks(self, mock_read):
@@ -196,7 +184,7 @@ def test_read_log_stream_should_support_multiple_chunks(self, mock_read):
task_log_reader = TaskLogReader()
log_stream = task_log_reader.read_log_stream(ti=self.ti, try_number=1, metadata={})
- self.assertEqual(["\n1st line\n", "\n2nd line\n", "\n3rd line\n"], list(log_stream))
+ assert ["\n1st line\n", "\n2nd line\n", "\n3rd line\n"] == list(log_stream)
mock_read.assert_has_calls(
[
@@ -217,7 +205,7 @@ def test_read_log_stream_should_read_each_try_in_turn(self, mock_read):
task_log_reader = TaskLogReader()
log_stream = task_log_reader.read_log_stream(ti=self.ti, try_number=None, metadata={})
- self.assertEqual(['\ntry_number=1.\n', '\ntry_number=2.\n', '\ntry_number=3.\n'], list(log_stream))
+ assert ['\ntry_number=1.\n', '\ntry_number=2.\n', '\ntry_number=3.\n'] == list(log_stream)
mock_read.assert_has_calls(
[
diff --git a/tests/utils/test_cli_util.py b/tests/utils/test_cli_util.py
index 5a64c457b9837..c567f44d1eab6 100644
--- a/tests/utils/test_cli_util.py
+++ b/tests/utils/test_cli_util.py
@@ -25,6 +25,7 @@
from datetime import datetime
from unittest import mock
+import pytest
from parameterized import parameterized
from airflow import settings
@@ -47,24 +48,24 @@ def test_metrics_build(self):
'execution_date': exec_date,
}
for k, v in expected.items():
- self.assertEqual(v, metrics.get(k))
+ assert v == metrics.get(k)
- self.assertTrue(metrics.get('start_datetime') <= datetime.utcnow())
- self.assertTrue(metrics.get('full_command'))
+ assert metrics.get('start_datetime') <= datetime.utcnow()
+ assert metrics.get('full_command')
log_dao = metrics.get('log')
- self.assertTrue(log_dao)
- self.assertEqual(log_dao.dag_id, metrics.get('dag_id'))
- self.assertEqual(log_dao.task_id, metrics.get('task_id'))
- self.assertEqual(log_dao.execution_date, metrics.get('execution_date'))
- self.assertEqual(log_dao.owner, metrics.get('user'))
+ assert log_dao
+ assert log_dao.dag_id == metrics.get('dag_id')
+ assert log_dao.task_id == metrics.get('task_id')
+ assert log_dao.execution_date == metrics.get('execution_date')
+ assert log_dao.owner == metrics.get('user')
def test_fail_function(self):
"""
Actual function is failing and fail needs to be propagated.
:return:
"""
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
fail_func(Namespace())
def test_success_function(self):
@@ -77,16 +78,16 @@ def test_success_function(self):
success_func(Namespace())
def test_process_subdir_path_with_placeholder(self):
- self.assertEqual(os.path.join(settings.DAGS_FOLDER, 'abc'), cli.process_subdir('DAGS_FOLDER/abc'))
+ assert os.path.join(settings.DAGS_FOLDER, 'abc') == cli.process_subdir('DAGS_FOLDER/abc')
def test_get_dags(self):
dags = cli.get_dags(None, "example_subdag_operator")
- self.assertEqual(len(dags), 1)
+ assert len(dags) == 1
dags = cli.get_dags(None, "subdag", True)
- self.assertGreater(len(dags), 1)
+ assert len(dags) > 1
- with self.assertRaises(AirflowException):
+ with pytest.raises(AirflowException):
cli.get_dags(None, "foobar", True)
@parameterized.expand(
@@ -123,30 +124,30 @@ def test_cli_create_user_supplied_password_is_masked(self, given_command, expect
with mock.patch.object(sys, "argv", args):
metrics = cli._build_metrics(args[1], namespace)
- self.assertTrue(metrics.get('start_datetime') <= datetime.utcnow())
+ assert metrics.get('start_datetime') <= datetime.utcnow()
log = metrics.get('log')
command = json.loads(log.extra).get('full_command') # type: str
# Replace single quotes to double quotes to avoid json decode error
command = json.loads(command.replace("'", '"'))
- self.assertEqual(command, expected_command)
+ assert command == expected_command
def test_setup_locations_relative_pid_path(self):
relative_pid_path = "fake.pid"
pid_full_path = os.path.join(os.getcwd(), relative_pid_path)
pid, _, _, _ = cli.setup_locations(process="fake_process", pid=relative_pid_path)
- self.assertEqual(pid, pid_full_path)
+ assert pid == pid_full_path
def test_setup_locations_absolute_pid_path(self):
abs_pid_path = os.path.join(os.getcwd(), "fake.pid")
pid, _, _, _ = cli.setup_locations(process="fake_process", pid=abs_pid_path)
- self.assertEqual(pid, abs_pid_path)
+ assert pid == abs_pid_path
def test_setup_locations_none_pid_path(self):
process_name = "fake_process"
default_pid_path = os.path.join(settings.AIRFLOW_HOME, f"airflow-{process_name}.pid")
pid, _, _, _ = cli.setup_locations(process=process_name)
- self.assertEqual(pid, default_pid_path)
+ assert pid == default_pid_path
@contextmanager
diff --git a/tests/utils/test_compression.py b/tests/utils/test_compression.py
index b288ca60fde64..cfa84f5f03fe6 100644
--- a/tests/utils/test_compression.py
+++ b/tests/utils/test_compression.py
@@ -25,6 +25,8 @@
import tempfile
import unittest
+import pytest
+
from airflow.utils import compression
@@ -76,22 +78,16 @@ def _get_fn(self, ext):
def test_uncompress_file(self):
# Testing txt file type
- self.assertRaisesRegex(
- NotImplementedError,
- "^Received .txt format. Only gz and bz2.*",
- compression.uncompress_file,
- **{'input_file_name': None, 'file_extension': '.txt', 'dest_dir': None},
- )
+ with pytest.raises(NotImplementedError, match="^Received .txt format. Only gz and bz2.*"):
+ compression.uncompress_file(
+ **{'input_file_name': None, 'file_extension': '.txt', 'dest_dir': None},
+ )
# Testing gz file type
fn_txt = self._get_fn('.txt')
fn_gz = self._get_fn('.gz')
txt_gz = compression.uncompress_file(fn_gz, '.gz', self.tmp_dir)
- self.assertTrue(
- filecmp.cmp(txt_gz, fn_txt, shallow=False), msg="Uncompressed file doest match original"
- )
+ assert filecmp.cmp(txt_gz, fn_txt, shallow=False), "Uncompressed file doest match original"
# Testing bz2 file type
fn_bz2 = self._get_fn('.bz2')
txt_bz2 = compression.uncompress_file(fn_bz2, '.bz2', self.tmp_dir)
- self.assertTrue(
- filecmp.cmp(txt_bz2, fn_txt, shallow=False), msg="Uncompressed file doest match original"
- )
+ assert filecmp.cmp(txt_bz2, fn_txt, shallow=False), "Uncompressed file doest match original"
diff --git a/tests/utils/test_dag_cycle.py b/tests/utils/test_dag_cycle.py
index 7602c05522c27..898985bbf29f7 100644
--- a/tests/utils/test_dag_cycle.py
+++ b/tests/utils/test_dag_cycle.py
@@ -17,6 +17,8 @@
import unittest
+import pytest
+
from airflow import DAG
from airflow.exceptions import AirflowDagCycleException
from airflow.operators.dummy import DummyOperator
@@ -29,7 +31,7 @@ def test_cycle_empty(self):
# test empty
dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
- self.assertFalse(_test_cycle(dag))
+ assert not _test_cycle(dag)
def test_cycle_single_task(self):
# test single task
@@ -38,7 +40,7 @@ def test_cycle_single_task(self):
with dag:
DummyOperator(task_id='A')
- self.assertFalse(_test_cycle(dag))
+ assert not _test_cycle(dag)
def test_semi_complex(self):
dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
@@ -75,7 +77,7 @@ def test_cycle_no_cycle(self):
op2.set_downstream(op4)
op5.set_downstream(op6)
- self.assertFalse(_test_cycle(dag))
+ assert not _test_cycle(dag)
def test_cycle_loop(self):
# test self loop
@@ -86,8 +88,8 @@ def test_cycle_loop(self):
op1 = DummyOperator(task_id='A')
op1.set_downstream(op1)
- with self.assertRaises(AirflowDagCycleException):
- self.assertFalse(_test_cycle(dag))
+ with pytest.raises(AirflowDagCycleException):
+ assert not _test_cycle(dag)
def test_cycle_downstream_loop(self):
# test downstream self loop
@@ -106,8 +108,8 @@ def test_cycle_downstream_loop(self):
op4.set_downstream(op5)
op5.set_downstream(op5)
- with self.assertRaises(AirflowDagCycleException):
- self.assertFalse(_test_cycle(dag))
+ with pytest.raises(AirflowDagCycleException):
+ assert not _test_cycle(dag)
def test_cycle_large_loop(self):
# large loop
@@ -124,8 +126,8 @@ def test_cycle_large_loop(self):
current = next_task
current.set_downstream(start)
- with self.assertRaises(AirflowDagCycleException):
- self.assertFalse(_test_cycle(dag))
+ with pytest.raises(AirflowDagCycleException):
+ assert not _test_cycle(dag)
def test_cycle_arbitrary_loop(self):
# test arbitrary loop
@@ -146,5 +148,5 @@ def test_cycle_arbitrary_loop(self):
op2.set_downstream(op5)
op5.set_downstream(op1)
- with self.assertRaises(AirflowDagCycleException):
- self.assertFalse(_test_cycle(dag))
+ with pytest.raises(AirflowDagCycleException):
+ assert not _test_cycle(dag)
diff --git a/tests/utils/test_dag_processing.py b/tests/utils/test_dag_processing.py
index ad8ef5a5c95b4..dc082104a9165 100644
--- a/tests/utils/test_dag_processing.py
+++ b/tests/utils/test_dag_processing.py
@@ -201,7 +201,7 @@ def test_set_file_paths_when_processor_file_path_not_in_new_file_paths(self):
manager._file_stats['missing_file.txt'] = DagFileStat(0, 0, None, None, 0)
manager.set_file_paths(['abc.txt'])
- self.assertDictEqual(manager._processors, {})
+ assert manager._processors == {}
def test_set_file_paths_when_processor_file_path_is_in_new_file_paths(self):
manager = DagFileProcessorManager(
@@ -222,7 +222,7 @@ def test_set_file_paths_when_processor_file_path_is_in_new_file_paths(self):
manager._processors['abc.txt'] = mock_processor
manager.set_file_paths(['abc.txt'])
- self.assertDictEqual(manager._processors, {'abc.txt': mock_processor})
+ assert manager._processors == {'abc.txt': mock_processor}
def test_find_zombies(self):
manager = DagFileProcessorManager(
@@ -259,14 +259,14 @@ def test_find_zombies(self):
)
manager._find_zombies() # pylint: disable=no-value-for-parameter
requests = manager._callback_to_execute[dag.full_filepath]
- self.assertEqual(1, len(requests))
- self.assertEqual(requests[0].full_filepath, dag.full_filepath)
- self.assertEqual(requests[0].msg, "Detected as zombie")
- self.assertEqual(requests[0].is_failure_callback, True)
- self.assertIsInstance(requests[0].simple_task_instance, SimpleTaskInstance)
- self.assertEqual(ti.dag_id, requests[0].simple_task_instance.dag_id)
- self.assertEqual(ti.task_id, requests[0].simple_task_instance.task_id)
- self.assertEqual(ti.execution_date, requests[0].simple_task_instance.execution_date)
+ assert 1 == len(requests)
+ assert requests[0].full_filepath == dag.full_filepath
+ assert requests[0].msg == "Detected as zombie"
+ assert requests[0].is_failure_callback is True
+ assert isinstance(requests[0].simple_task_instance, SimpleTaskInstance)
+ assert ti.dag_id == requests[0].simple_task_instance.dag_id
+ assert ti.task_id == requests[0].simple_task_instance.task_id
+ assert ti.execution_date == requests[0].simple_task_instance.execution_date
session.query(TI).delete()
session.query(LJ).delete()
@@ -486,7 +486,7 @@ class path, thus when reloading logging module the airflow.processor_manager
# Since we are reloading logging config not creating this file,
# we should expect it to be nonexistent.
- self.assertFalse(os.path.isfile(log_file_loc))
+ assert not os.path.isfile(log_file_loc)
@conf_vars({('core', 'load_examples'): 'False'})
def test_parse_once(self):
@@ -536,7 +536,7 @@ def test_launch_process(self):
processor_agent._process.join()
- self.assertTrue(os.path.isfile(log_file_loc))
+ assert os.path.isfile(log_file_loc)
class TestCorrectMaybeZipped(unittest.TestCase):
@@ -547,7 +547,7 @@ def test_correct_maybe_zipped_normal_file(self, mocked_is_zipfile):
dag_folder = correct_maybe_zipped(path)
- self.assertEqual(dag_folder, path)
+ assert dag_folder == path
@mock.patch("zipfile.is_zipfile")
def test_correct_maybe_zipped_normal_file_with_zip_in_name(self, mocked_is_zipfile):
@@ -556,7 +556,7 @@ def test_correct_maybe_zipped_normal_file_with_zip_in_name(self, mocked_is_zipfi
dag_folder = correct_maybe_zipped(path)
- self.assertEqual(dag_folder, path)
+ assert dag_folder == path
@mock.patch("zipfile.is_zipfile")
def test_correct_maybe_zipped_archive(self, mocked_is_zipfile):
@@ -567,9 +567,9 @@ def test_correct_maybe_zipped_archive(self, mocked_is_zipfile):
assert mocked_is_zipfile.call_count == 1
(args, kwargs) = mocked_is_zipfile.call_args_list[0]
- self.assertEqual('/path/to/archive.zip', args[0])
+ assert '/path/to/archive.zip' == args[0]
- self.assertEqual(dag_folder, '/path/to/archive.zip')
+ assert dag_folder == '/path/to/archive.zip'
class TestOpenMaybeZipped(unittest.TestCase):
diff --git a/tests/utils/test_dates.py b/tests/utils/test_dates.py
index c06022e0a4615..6eb3b115cf7da 100644
--- a/tests/utils/test_dates.py
+++ b/tests/utils/test_dates.py
@@ -20,6 +20,7 @@
from datetime import datetime, timedelta
import pendulum
+import pytest
from dateutil.relativedelta import relativedelta
from pytest import approx
@@ -31,27 +32,25 @@ def test_days_ago(self):
today = pendulum.today()
today_midnight = pendulum.instance(datetime.fromordinal(today.date().toordinal()))
- self.assertEqual(dates.days_ago(0), today_midnight)
- self.assertEqual(dates.days_ago(100), today_midnight + timedelta(days=-100))
+ assert dates.days_ago(0) == today_midnight
+ assert dates.days_ago(100) == today_midnight + timedelta(days=-100)
- self.assertEqual(dates.days_ago(0, hour=3), today_midnight + timedelta(hours=3))
- self.assertEqual(dates.days_ago(0, minute=3), today_midnight + timedelta(minutes=3))
- self.assertEqual(dates.days_ago(0, second=3), today_midnight + timedelta(seconds=3))
- self.assertEqual(dates.days_ago(0, microsecond=3), today_midnight + timedelta(microseconds=3))
+ assert dates.days_ago(0, hour=3) == today_midnight + timedelta(hours=3)
+ assert dates.days_ago(0, minute=3) == today_midnight + timedelta(minutes=3)
+ assert dates.days_ago(0, second=3) == today_midnight + timedelta(seconds=3)
+ assert dates.days_ago(0, microsecond=3) == today_midnight + timedelta(microseconds=3)
def test_parse_execution_date(self):
execution_date_str_wo_ms = '2017-11-02 00:00:00'
execution_date_str_w_ms = '2017-11-05 16:18:30.989729'
bad_execution_date_str = '2017-11-06TXX:00:00Z'
- self.assertEqual(
- timezone.datetime(2017, 11, 2, 0, 0, 0), dates.parse_execution_date(execution_date_str_wo_ms)
+ assert timezone.datetime(2017, 11, 2, 0, 0, 0) == dates.parse_execution_date(execution_date_str_wo_ms)
+ assert timezone.datetime(2017, 11, 5, 16, 18, 30, 989729) == dates.parse_execution_date(
+ execution_date_str_w_ms
)
- self.assertEqual(
- timezone.datetime(2017, 11, 5, 16, 18, 30, 989729),
- dates.parse_execution_date(execution_date_str_w_ms),
- )
- self.assertRaises(ValueError, dates.parse_execution_date, bad_execution_date_str)
+ with pytest.raises(ValueError):
+ dates.parse_execution_date(bad_execution_date_str)
def test_round_time(self):
@@ -107,41 +106,41 @@ def test_scale_time_units(self):
class TestUtilsDatesDateRange(unittest.TestCase):
def test_no_delta(self):
- self.assertEqual(dates.date_range(datetime(2016, 1, 1), datetime(2016, 1, 3)), [])
+ assert dates.date_range(datetime(2016, 1, 1), datetime(2016, 1, 3)) == []
def test_end_date_before_start_date(self):
- with self.assertRaisesRegex(Exception, "Wait. start_date needs to be before end_date"):
+ with pytest.raises(Exception, match="Wait. start_date needs to be before end_date"):
dates.date_range(datetime(2016, 2, 1), datetime(2016, 1, 1), delta=timedelta(seconds=1))
def test_both_end_date_and_num_given(self):
- with self.assertRaisesRegex(Exception, "Wait. Either specify end_date OR num"):
+ with pytest.raises(Exception, match="Wait. Either specify end_date OR num"):
dates.date_range(datetime(2016, 1, 1), datetime(2016, 1, 3), num=2, delta=timedelta(seconds=1))
def test_invalid_delta(self):
exception_msg = "Wait. delta must be either datetime.timedelta or cron expression as str"
- with self.assertRaisesRegex(Exception, exception_msg):
+ with pytest.raises(Exception, match=exception_msg):
dates.date_range(datetime(2016, 1, 1), datetime(2016, 1, 3), delta=1)
def test_positive_num_given(self):
for num in range(1, 10):
result = dates.date_range(datetime(2016, 1, 1), num=num, delta=timedelta(1))
- self.assertEqual(len(result), num)
+ assert len(result) == num
for i in range(num):
- self.assertTrue(timezone.is_localized(result[i]))
+ assert timezone.is_localized(result[i])
def test_negative_num_given(self):
for num in range(-1, -5, -10):
result = dates.date_range(datetime(2016, 1, 1), num=num, delta=timedelta(1))
- self.assertEqual(len(result), -num)
+ assert len(result) == -num
for i in range(num):
- self.assertTrue(timezone.is_localized(result[i]))
+ assert timezone.is_localized(result[i])
def test_delta_cron_presets(self):
preset_range = dates.date_range(datetime(2016, 1, 1), num=2, delta="@hourly")
timedelta_range = dates.date_range(datetime(2016, 1, 1), num=2, delta=timedelta(hours=1))
cron_range = dates.date_range(datetime(2016, 1, 1), num=2, delta="0 * * * *")
- self.assertEqual(preset_range, timedelta_range)
- self.assertEqual(preset_range, cron_range)
+ assert preset_range == timedelta_range
+ assert preset_range == cron_range
diff --git a/tests/utils/test_db.py b/tests/utils/test_db.py
index 70a1a9c12794e..e608ff1734c73 100644
--- a/tests/utils/test_db.py
+++ b/tests/utils/test_db.py
@@ -72,7 +72,7 @@ def test_database_schema_and_sqlalchemy_model_are_in_sync(self):
for ignore in ignores:
diff = [d for d in diff if not ignore(d)]
- self.assertFalse(diff, 'Database schema and SQLAlchemy model are not in sync: ' + str(diff))
+ assert not diff, 'Database schema and SQLAlchemy model are not in sync: ' + str(diff)
def test_only_single_head_revision_in_migrations(self):
config = Config()
@@ -87,4 +87,4 @@ def test_default_connections_sort(self):
pattern = re.compile('conn_id=[\"|\'](.*?)[\"|\']', re.DOTALL)
source = inspect.getsource(create_default_connections)
src = pattern.findall(source)
- self.assertListEqual(sorted(src), src)
+ assert sorted(src) == src
diff --git a/tests/utils/test_decorators.py b/tests/utils/test_decorators.py
index 274eb77da0711..372d93d473676 100644
--- a/tests/utils/test_decorators.py
+++ b/tests/utils/test_decorators.py
@@ -18,6 +18,8 @@
import unittest
+import pytest
+
from airflow.exceptions import AirflowException
from airflow.utils.decorators import apply_defaults
@@ -39,34 +41,34 @@ def __init__(self, test_sub_param, **kwargs):
class TestApplyDefault(unittest.TestCase):
def test_apply(self):
dummy = DummyClass(test_param=True)
- self.assertTrue(dummy.test_param)
+ assert dummy.test_param
- with self.assertRaisesRegex(AirflowException, 'Argument.*test_param.*required'):
+ with pytest.raises(AirflowException, match='Argument.*test_param.*required'):
DummySubClass(test_sub_param=True)
def test_default_args(self):
default_args = {'test_param': True}
dummy_class = DummyClass(default_args=default_args) # pylint: disable=no-value-for-parameter
- self.assertTrue(dummy_class.test_param)
+ assert dummy_class.test_param
default_args = {'test_param': True, 'test_sub_param': True}
dummy_subclass = DummySubClass(default_args=default_args) # pylint: disable=no-value-for-parameter
- self.assertTrue(dummy_class.test_param)
- self.assertTrue(dummy_subclass.test_sub_param)
+ assert dummy_class.test_param
+ assert dummy_subclass.test_sub_param
default_args = {'test_param': True}
dummy_subclass = DummySubClass(default_args=default_args, test_sub_param=True)
- self.assertTrue(dummy_class.test_param)
- self.assertTrue(dummy_subclass.test_sub_param)
+ assert dummy_class.test_param
+ assert dummy_subclass.test_sub_param
- with self.assertRaisesRegex(AirflowException, 'Argument.*test_sub_param.*required'):
+ with pytest.raises(AirflowException, match='Argument.*test_sub_param.*required'):
DummySubClass(default_args=default_args) # pylint: disable=no-value-for-parameter
def test_incorrect_default_args(self):
default_args = {'test_param': True, 'extra_param': True}
dummy_class = DummyClass(default_args=default_args) # pylint: disable=no-value-for-parameter
- self.assertTrue(dummy_class.test_param)
+ assert dummy_class.test_param
default_args = {'random_params': True}
- with self.assertRaisesRegex(AirflowException, 'Argument.*test_param.*required'):
+ with pytest.raises(AirflowException, match='Argument.*test_param.*required'):
DummyClass(default_args=default_args) # pylint: disable=no-value-for-parameter
diff --git a/tests/utils/test_docs.py b/tests/utils/test_docs.py
index 3bcdb2510c10c..fd6a5845a4405 100644
--- a/tests/utils/test_docs.py
+++ b/tests/utils/test_docs.py
@@ -44,4 +44,4 @@ class TestGetDocsUrl(unittest.TestCase):
)
def test_should_return_link(self, version, page, expected_urk):
with mock.patch('airflow.version.version', version):
- self.assertEqual(expected_urk, get_docs_url(page))
+ assert expected_urk == get_docs_url(page)
diff --git a/tests/utils/test_dot_renderer.py b/tests/utils/test_dot_renderer.py
index 3016b0c32ff4c..b0306233fc9dc 100644
--- a/tests/utils/test_dot_renderer.py
+++ b/tests/utils/test_dot_renderer.py
@@ -45,14 +45,14 @@ def test_should_render_dag(self):
dot = dot_renderer.render_dag(dag)
source = dot.source
# Should render DAG title
- self.assertIn("label=DAG_ID", source)
- self.assertIn("first", source)
- self.assertIn("second", source)
- self.assertIn("third", source)
- self.assertIn("first -> second", source)
- self.assertIn("first -> third", source)
- self.assertIn('fillcolor="#f0ede4"', source)
- self.assertIn('fillcolor="#f0ede4"', source)
+ assert "label=DAG_ID" in source
+ assert "first" in source
+ assert "second" in source
+ assert "third" in source
+ assert "first -> second" in source
+ assert "first -> third" in source
+ assert 'fillcolor="#f0ede4"' in source
+ assert 'fillcolor="#f0ede4"' in source
def test_should_render_dag_with_task_instances(self):
dag = DAG(dag_id="DAG_ID")
@@ -71,10 +71,10 @@ def test_should_render_dag_with_task_instances(self):
dot = dot_renderer.render_dag(dag, tis=tis)
source = dot.source
# Should render DAG title
- self.assertIn("label=DAG_ID", source)
- self.assertIn('first [color=black fillcolor=tan shape=rectangle style="filled,rounded"]', source)
- self.assertIn('second [color=white fillcolor=green shape=rectangle style="filled,rounded"]', source)
- self.assertIn('third [color=black fillcolor=lime shape=rectangle style="filled,rounded"]', source)
+ assert "label=DAG_ID" in source
+ assert 'first [color=black fillcolor=tan shape=rectangle style="filled,rounded"]' in source
+ assert 'second [color=white fillcolor=green shape=rectangle style="filled,rounded"]' in source
+ assert 'third [color=black fillcolor=lime shape=rectangle style="filled,rounded"]' in source
def test_should_render_dag_orientation(self):
orientation = "TB"
@@ -94,8 +94,8 @@ def test_should_render_dag_orientation(self):
dot = dot_renderer.render_dag(dag, tis=tis)
source = dot.source
# Should render DAG title with orientation
- self.assertIn("label=DAG_ID", source)
- self.assertIn(f'label=DAG_ID labelloc=t rankdir={orientation}', source)
+ assert "label=DAG_ID" in source
+ assert f'label=DAG_ID labelloc=t rankdir={orientation}' in source
# Change orientation
orientation = "LR"
@@ -103,5 +103,5 @@ def test_should_render_dag_orientation(self):
dot = dot_renderer.render_dag(dag, tis=tis)
source = dot.source
# Should render DAG title with orientation
- self.assertIn("label=DAG_ID", source)
- self.assertIn(f'label=DAG_ID labelloc=t rankdir={orientation}', source)
+ assert "label=DAG_ID" in source
+ assert f'label=DAG_ID labelloc=t rankdir={orientation}' in source
diff --git a/tests/utils/test_email.py b/tests/utils/test_email.py
index 8966e975d3647..a34dc7d845766 100644
--- a/tests/utils/test_email.py
+++ b/tests/utils/test_email.py
@@ -24,6 +24,8 @@
from smtplib import SMTPServerDisconnected
from unittest import mock
+import pytest
+
from airflow import utils
from airflow.configuration import conf
from airflow.utils.email import build_mime_message, get_email_address_list
@@ -38,37 +40,39 @@ class TestEmail(unittest.TestCase):
def test_get_email_address_single_email(self):
emails_string = 'test1@example.com'
- self.assertEqual(get_email_address_list(emails_string), [emails_string])
+ assert get_email_address_list(emails_string) == [emails_string]
def test_get_email_address_comma_sep_string(self):
emails_string = 'test1@example.com, test2@example.com'
- self.assertEqual(get_email_address_list(emails_string), EMAILS)
+ assert get_email_address_list(emails_string) == EMAILS
def test_get_email_address_colon_sep_string(self):
emails_string = 'test1@example.com; test2@example.com'
- self.assertEqual(get_email_address_list(emails_string), EMAILS)
+ assert get_email_address_list(emails_string) == EMAILS
def test_get_email_address_list(self):
emails_list = ['test1@example.com', 'test2@example.com']
- self.assertEqual(get_email_address_list(emails_list), EMAILS)
+ assert get_email_address_list(emails_list) == EMAILS
def test_get_email_address_tuple(self):
emails_tuple = ('test1@example.com', 'test2@example.com')
- self.assertEqual(get_email_address_list(emails_tuple), EMAILS)
+ assert get_email_address_list(emails_tuple) == EMAILS
def test_get_email_address_invalid_type(self):
emails_string = 1
- self.assertRaises(TypeError, get_email_address_list, emails_string)
+ with pytest.raises(TypeError):
+ get_email_address_list(emails_string)
def test_get_email_address_invalid_type_in_iterable(self):
emails_list = ['test1@example.com', 2]
- self.assertRaises(TypeError, get_email_address_list, emails_list)
+ with pytest.raises(TypeError):
+ get_email_address_list(emails_list)
def setUp(self):
conf.remove_option('email', 'EMAIL_BACKEND')
@@ -77,7 +81,7 @@ def setUp(self):
def test_default_backend(self, mock_send_email):
res = utils.email.send_email('to', 'subject', 'content')
mock_send_email.assert_called_once_with('to', 'subject', 'content')
- self.assertEqual(mock_send_email.return_value, res)
+ assert mock_send_email.return_value == res
@mock.patch('airflow.utils.email.send_email_smtp')
def test_custom_backend(self, mock_send_email):
@@ -94,7 +98,7 @@ def test_custom_backend(self, mock_send_email):
mime_charset='utf-8',
mime_subtype='mixed',
)
- self.assertFalse(mock_send_email.called)
+ assert not mock_send_email.called
def test_build_mime_message(self):
mail_from = 'from@example.com'
@@ -111,12 +115,12 @@ def test_build_mime_message(self):
custom_headers=custom_headers,
)
- self.assertIn('From', msg)
- self.assertIn('To', msg)
- self.assertIn('Subject', msg)
- self.assertIn('Reply-To', msg)
- self.assertListEqual([mail_to], recipients)
- self.assertEqual(msg['To'], ','.join(recipients))
+ assert 'From' in msg
+ assert 'To' in msg
+ assert 'Subject' in msg
+ assert 'Reply-To' in msg
+ assert [mail_to] == recipients
+ assert msg['To'] == ','.join(recipients)
class TestEmailSmtp(unittest.TestCase):
@@ -126,27 +130,27 @@ def test_send_smtp(self, mock_send_mime):
attachment.write(b'attachment')
attachment.seek(0)
utils.email.send_email_smtp('to', 'subject', 'content', files=[attachment.name])
- self.assertTrue(mock_send_mime.called)
+ assert mock_send_mime.called
_, call_args = mock_send_mime.call_args
- self.assertEqual(conf.get('smtp', 'SMTP_MAIL_FROM'), call_args['e_from'])
- self.assertEqual(['to'], call_args['e_to'])
+ assert conf.get('smtp', 'SMTP_MAIL_FROM') == call_args['e_from']
+ assert ['to'] == call_args['e_to']
msg = call_args['mime_msg']
- self.assertEqual('subject', msg['Subject'])
- self.assertEqual(conf.get('smtp', 'SMTP_MAIL_FROM'), msg['From'])
- self.assertEqual(2, len(msg.get_payload()))
+ assert 'subject' == msg['Subject']
+ assert conf.get('smtp', 'SMTP_MAIL_FROM') == msg['From']
+ assert 2 == len(msg.get_payload())
filename = 'attachment; filename="' + os.path.basename(attachment.name) + '"'
- self.assertEqual(filename, msg.get_payload()[-1].get('Content-Disposition'))
+ assert filename == msg.get_payload()[-1].get('Content-Disposition')
mimeapp = MIMEApplication('attachment')
- self.assertEqual(mimeapp.get_payload(), msg.get_payload()[-1].get_payload())
+ assert mimeapp.get_payload() == msg.get_payload()[-1].get_payload()
@mock.patch('airflow.utils.email.send_mime_email')
def test_send_smtp_with_multibyte_content(self, mock_send_mime):
utils.email.send_email_smtp('to', 'subject', '🔥', mime_charset='utf-8')
- self.assertTrue(mock_send_mime.called)
+ assert mock_send_mime.called
_, call_args = mock_send_mime.call_args
msg = call_args['mime_msg']
mimetext = MIMEText('🔥', 'mixed', 'utf-8')
- self.assertEqual(mimetext.get_payload(), msg.get_payload()[0].get_payload())
+ assert mimetext.get_payload() == msg.get_payload()[0].get_payload()
@mock.patch('airflow.utils.email.send_mime_email')
def test_send_bcc_smtp(self, mock_send_mime):
@@ -154,20 +158,19 @@ def test_send_bcc_smtp(self, mock_send_mime):
attachment.write(b'attachment')
attachment.seek(0)
utils.email.send_email_smtp('to', 'subject', 'content', files=[attachment.name], cc='cc', bcc='bcc')
- self.assertTrue(mock_send_mime.called)
+ assert mock_send_mime.called
_, call_args = mock_send_mime.call_args
- self.assertEqual(conf.get('smtp', 'SMTP_MAIL_FROM'), call_args['e_from'])
- self.assertEqual(['to', 'cc', 'bcc'], call_args['e_to'])
+ assert conf.get('smtp', 'SMTP_MAIL_FROM') == call_args['e_from']
+ assert ['to', 'cc', 'bcc'] == call_args['e_to']
msg = call_args['mime_msg']
- self.assertEqual('subject', msg['Subject'])
- self.assertEqual(conf.get('smtp', 'SMTP_MAIL_FROM'), msg['From'])
- self.assertEqual(2, len(msg.get_payload()))
- self.assertEqual(
- 'attachment; filename="' + os.path.basename(attachment.name) + '"',
- msg.get_payload()[-1].get('Content-Disposition'),
- )
+ assert 'subject' == msg['Subject']
+ assert conf.get('smtp', 'SMTP_MAIL_FROM') == msg['From']
+ assert 2 == len(msg.get_payload())
+ assert 'attachment; filename="' + os.path.basename(attachment.name) + '"' == msg.get_payload()[
+ -1
+ ].get('Content-Disposition')
mimeapp = MIMEApplication('attachment')
- self.assertEqual(mimeapp.get_payload(), msg.get_payload()[-1].get_payload())
+ assert mimeapp.get_payload() == msg.get_payload()[-1].get_payload()
@mock.patch('smtplib.SMTP_SSL')
@mock.patch('smtplib.SMTP')
@@ -180,14 +183,14 @@ def test_send_mime(self, mock_smtp, mock_smtp_ssl):
port=conf.getint('smtp', 'SMTP_PORT'),
timeout=conf.getint('smtp', 'SMTP_TIMEOUT'),
)
- self.assertFalse(mock_smtp_ssl.called)
- self.assertTrue(mock_smtp.return_value.starttls.called)
+ assert not mock_smtp_ssl.called
+ assert mock_smtp.return_value.starttls.called
mock_smtp.return_value.login.assert_called_once_with(
conf.get('smtp', 'SMTP_USER'),
conf.get('smtp', 'SMTP_PASSWORD'),
)
mock_smtp.return_value.sendmail.assert_called_once_with('from', 'to', msg.as_string())
- self.assertTrue(mock_smtp.return_value.quit.called)
+ assert mock_smtp.return_value.quit.called
@mock.patch('smtplib.SMTP_SSL')
@mock.patch('smtplib.SMTP')
@@ -195,7 +198,7 @@ def test_send_mime_ssl(self, mock_smtp, mock_smtp_ssl):
mock_smtp_ssl.return_value = mock.Mock()
with conf_vars({('smtp', 'smtp_ssl'): 'True'}):
utils.email.send_mime_email('from', 'to', MIMEMultipart(), dryrun=False)
- self.assertFalse(mock_smtp.called)
+ assert not mock_smtp.called
mock_smtp_ssl.assert_called_once_with(
host=conf.get('smtp', 'SMTP_HOST'),
port=conf.getint('smtp', 'SMTP_PORT'),
@@ -213,27 +216,27 @@ def test_send_mime_noauth(self, mock_smtp, mock_smtp_ssl):
}
):
utils.email.send_mime_email('from', 'to', MIMEMultipart(), dryrun=False)
- self.assertFalse(mock_smtp_ssl.called)
+ assert not mock_smtp_ssl.called
mock_smtp.assert_called_once_with(
host=conf.get('smtp', 'SMTP_HOST'),
port=conf.getint('smtp', 'SMTP_PORT'),
timeout=conf.getint('smtp', 'SMTP_TIMEOUT'),
)
- self.assertFalse(mock_smtp.login.called)
+ assert not mock_smtp.login.called
@mock.patch('smtplib.SMTP_SSL')
@mock.patch('smtplib.SMTP')
def test_send_mime_dryrun(self, mock_smtp, mock_smtp_ssl):
utils.email.send_mime_email('from', 'to', MIMEMultipart(), dryrun=True)
- self.assertFalse(mock_smtp.called)
- self.assertFalse(mock_smtp_ssl.called)
+ assert not mock_smtp.called
+ assert not mock_smtp_ssl.called
@mock.patch('smtplib.SMTP_SSL')
@mock.patch('smtplib.SMTP')
def test_send_mime_complete_failure(self, mock_smtp: mock, mock_smtp_ssl):
mock_smtp.side_effect = SMTPServerDisconnected()
msg = MIMEMultipart()
- with self.assertRaises(SMTPServerDisconnected):
+ with pytest.raises(SMTPServerDisconnected):
utils.email.send_mime_email('from', 'to', msg, dryrun=False)
mock_smtp.assert_any_call(
@@ -241,12 +244,12 @@ def test_send_mime_complete_failure(self, mock_smtp: mock, mock_smtp_ssl):
port=conf.getint('smtp', 'SMTP_PORT'),
timeout=conf.getint('smtp', 'SMTP_TIMEOUT'),
)
- self.assertEqual(mock_smtp.call_count, conf.getint('smtp', 'SMTP_RETRY_LIMIT'))
- self.assertFalse(mock_smtp_ssl.called)
- self.assertFalse(mock_smtp.return_value.starttls.called)
- self.assertFalse(mock_smtp.return_value.login.called)
- self.assertFalse(mock_smtp.return_value.sendmail.called)
- self.assertFalse(mock_smtp.return_value.quit.called)
+ assert mock_smtp.call_count == conf.getint('smtp', 'SMTP_RETRY_LIMIT')
+ assert not mock_smtp_ssl.called
+ assert not mock_smtp.return_value.starttls.called
+ assert not mock_smtp.return_value.login.called
+ assert not mock_smtp.return_value.sendmail.called
+ assert not mock_smtp.return_value.quit.called
@mock.patch('smtplib.SMTP_SSL')
@mock.patch('smtplib.SMTP')
@@ -254,7 +257,7 @@ def test_send_mime_ssl_complete_failure(self, mock_smtp, mock_smtp_ssl):
mock_smtp_ssl.side_effect = SMTPServerDisconnected()
msg = MIMEMultipart()
with conf_vars({('smtp', 'smtp_ssl'): 'True'}):
- with self.assertRaises(SMTPServerDisconnected):
+ with pytest.raises(SMTPServerDisconnected):
utils.email.send_mime_email('from', 'to', msg, dryrun=False)
mock_smtp_ssl.assert_any_call(
@@ -262,12 +265,12 @@ def test_send_mime_ssl_complete_failure(self, mock_smtp, mock_smtp_ssl):
port=conf.getint('smtp', 'SMTP_PORT'),
timeout=conf.getint('smtp', 'SMTP_TIMEOUT'),
)
- self.assertEqual(mock_smtp_ssl.call_count, conf.getint('smtp', 'SMTP_RETRY_LIMIT'))
- self.assertFalse(mock_smtp.called)
- self.assertFalse(mock_smtp_ssl.return_value.starttls.called)
- self.assertFalse(mock_smtp_ssl.return_value.login.called)
- self.assertFalse(mock_smtp_ssl.return_value.sendmail.called)
- self.assertFalse(mock_smtp_ssl.return_value.quit.called)
+ assert mock_smtp_ssl.call_count == conf.getint('smtp', 'SMTP_RETRY_LIMIT')
+ assert not mock_smtp.called
+ assert not mock_smtp_ssl.return_value.starttls.called
+ assert not mock_smtp_ssl.return_value.login.called
+ assert not mock_smtp_ssl.return_value.sendmail.called
+ assert not mock_smtp_ssl.return_value.quit.called
@mock.patch('smtplib.SMTP_SSL')
@mock.patch('smtplib.SMTP')
@@ -284,14 +287,14 @@ def test_send_mime_custom_timeout_retrylimit(self, mock_smtp, mock_smtp_ssl):
('smtp', 'smtp_timeout'): str(custom_timeout),
}
):
- with self.assertRaises(SMTPServerDisconnected):
+ with pytest.raises(SMTPServerDisconnected):
utils.email.send_mime_email('from', 'to', msg, dryrun=False)
mock_smtp.assert_any_call(
host=conf.get('smtp', 'SMTP_HOST'), port=conf.getint('smtp', 'SMTP_PORT'), timeout=custom_timeout
)
- self.assertFalse(mock_smtp_ssl.called)
- self.assertEqual(mock_smtp.call_count, 10)
+ assert not mock_smtp_ssl.called
+ assert mock_smtp.call_count == 10
@mock.patch('smtplib.SMTP_SSL')
@mock.patch('smtplib.SMTP')
@@ -308,8 +311,8 @@ def test_send_mime_partial_failure(self, mock_smtp, mock_smtp_ssl):
port=conf.getint('smtp', 'SMTP_PORT'),
timeout=conf.getint('smtp', 'SMTP_TIMEOUT'),
)
- self.assertEqual(mock_smtp.call_count, side_effects.index(final_mock) + 1)
- self.assertFalse(mock_smtp_ssl.called)
- self.assertTrue(final_mock.starttls.called)
+ assert mock_smtp.call_count == side_effects.index(final_mock) + 1
+ assert not mock_smtp_ssl.called
+ assert final_mock.starttls.called
final_mock.sendmail.assert_called_once_with('from', 'to', msg.as_string())
- self.assertTrue(final_mock.quit.called)
+ assert final_mock.quit.called
diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py
index 26ad118753f75..fffa2d44a8e10 100644
--- a/tests/utils/test_helpers.py
+++ b/tests/utils/test_helpers.py
@@ -19,6 +19,8 @@
import unittest
from datetime import datetime
+import pytest
+
from airflow.models import TaskInstance
from airflow.models.dag import DAG
from airflow.operators.dummy import DummyOperator
@@ -47,59 +49,59 @@ def test_render_log_filename(self):
rendered_filename = helpers.render_log_filename(ti, try_number, filename_template)
- self.assertEqual(rendered_filename, expected_filename)
+ assert rendered_filename == expected_filename
def test_chunks(self):
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
list(helpers.chunks([1, 2, 3], 0))
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
list(helpers.chunks([1, 2, 3], -3))
- self.assertEqual(list(helpers.chunks([], 5)), [])
- self.assertEqual(list(helpers.chunks([1], 1)), [[1]])
- self.assertEqual(list(helpers.chunks([1, 2, 3], 2)), [[1, 2], [3]])
+ assert list(helpers.chunks([], 5)) == []
+ assert list(helpers.chunks([1], 1)) == [[1]]
+ assert list(helpers.chunks([1, 2, 3], 2)) == [[1, 2], [3]]
def test_reduce_in_chunks(self):
- self.assertEqual(
- helpers.reduce_in_chunks(lambda x, y: x + [y], [1, 2, 3, 4, 5], []), [[1, 2, 3, 4, 5]]
- )
+ assert helpers.reduce_in_chunks(lambda x, y: x + [y], [1, 2, 3, 4, 5], []) == [[1, 2, 3, 4, 5]]
- self.assertEqual(
- helpers.reduce_in_chunks(lambda x, y: x + [y], [1, 2, 3, 4, 5], [], 2), [[1, 2], [3, 4], [5]]
- )
+ assert helpers.reduce_in_chunks(lambda x, y: x + [y], [1, 2, 3, 4, 5], [], 2) == [[1, 2], [3, 4], [5]]
- self.assertEqual(helpers.reduce_in_chunks(lambda x, y: x + y[0] * y[1], [1, 2, 3, 4], 0, 2), 14)
+ assert helpers.reduce_in_chunks(lambda x, y: x + y[0] * y[1], [1, 2, 3, 4], 0, 2) == 14
def test_is_container(self):
- self.assertFalse(helpers.is_container("a string is not a container"))
- self.assertTrue(helpers.is_container(["a", "list", "is", "a", "container"]))
+ assert not helpers.is_container("a string is not a container")
+ assert helpers.is_container(["a", "list", "is", "a", "container"])
- self.assertTrue(helpers.is_container(['test_list']))
- self.assertFalse(helpers.is_container('test_str_not_iterable'))
+ assert helpers.is_container(['test_list'])
+ assert not helpers.is_container('test_str_not_iterable')
# Pass an object that is not iter nor a string.
- self.assertFalse(helpers.is_container(10))
+ assert not helpers.is_container(10)
def test_as_tuple(self):
- self.assertEqual(helpers.as_tuple("a string is not a container"), ("a string is not a container",))
-
- self.assertEqual(
- helpers.as_tuple(["a", "list", "is", "a", "container"]), ("a", "list", "is", "a", "container")
+ assert helpers.as_tuple("a string is not a container") == ("a string is not a container",)
+
+ assert helpers.as_tuple(["a", "list", "is", "a", "container"]) == (
+ "a",
+ "list",
+ "is",
+ "a",
+ "container",
)
def test_as_tuple_iter(self):
test_list = ['test_str']
as_tup = helpers.as_tuple(test_list)
- self.assertTupleEqual(tuple(test_list), as_tup)
+ assert tuple(test_list) == as_tup
def test_as_tuple_no_iter(self):
test_str = 'test_str'
as_tup = helpers.as_tuple(test_str)
- self.assertTupleEqual((test_str,), as_tup)
+ assert (test_str,) == as_tup
def test_convert_camel_to_snake(self):
- self.assertEqual(helpers.convert_camel_to_snake('LocalTaskJob'), 'local_task_job')
- self.assertEqual(helpers.convert_camel_to_snake('somethingVeryRandom'), 'something_very_random')
+ assert helpers.convert_camel_to_snake('LocalTaskJob') == 'local_task_job'
+ assert helpers.convert_camel_to_snake('somethingVeryRandom') == 'something_very_random'
def test_merge_dicts(self):
"""
@@ -108,7 +110,7 @@ def test_merge_dicts(self):
dict1 = {'a': 1, 'b': 2, 'c': 3}
dict2 = {'a': 1, 'b': 3, 'd': 42}
merged = merge_dicts(dict1, dict2)
- self.assertDictEqual(merged, {'a': 1, 'b': 3, 'c': 3, 'd': 42})
+ assert merged == {'a': 1, 'b': 3, 'c': 3, 'd': 42}
def test_merge_dicts_recursive_overlap_l1(self):
"""
@@ -117,7 +119,7 @@ def test_merge_dicts_recursive_overlap_l1(self):
dict1 = {'a': 1, 'r': {'a': 1, 'b': 2}}
dict2 = {'a': 1, 'r': {'c': 3, 'b': 0}}
merged = merge_dicts(dict1, dict2)
- self.assertDictEqual(merged, {'a': 1, 'r': {'a': 1, 'b': 0, 'c': 3}})
+ assert merged == {'a': 1, 'r': {'a': 1, 'b': 0, 'c': 3}}
def test_merge_dicts_recursive_overlap_l2(self):
"""
@@ -127,7 +129,7 @@ def test_merge_dicts_recursive_overlap_l2(self):
dict1 = {'a': 1, 'r': {'a': 1, 'b': {'a': 1}}}
dict2 = {'a': 1, 'r': {'c': 3, 'b': {'b': 1}}}
merged = merge_dicts(dict1, dict2)
- self.assertDictEqual(merged, {'a': 1, 'r': {'a': 1, 'b': {'a': 1, 'b': 1}, 'c': 3}})
+ assert merged == {'a': 1, 'r': {'a': 1, 'b': {'a': 1, 'b': 1}, 'c': 3}}
def test_merge_dicts_recursive_right_only(self):
"""
@@ -136,7 +138,7 @@ def test_merge_dicts_recursive_right_only(self):
dict1 = {'a': 1}
dict2 = {'a': 1, 'r': {'c': 3, 'b': 0}}
merged = merge_dicts(dict1, dict2)
- self.assertDictEqual(merged, {'a': 1, 'r': {'b': 0, 'c': 3}})
+ assert merged == {'a': 1, 'r': {'b': 0, 'c': 3}}
@conf_vars(
{
diff --git a/tests/utils/test_json.py b/tests/utils/test_json.py
index a127ff98e6cab..ebe757e74ec9a 100644
--- a/tests/utils/test_json.py
+++ b/tests/utils/test_json.py
@@ -21,6 +21,7 @@
from datetime import date, datetime
import numpy as np
+import pytest
from airflow.utils import json as utils_json
@@ -28,19 +29,19 @@
class TestAirflowJsonEncoder(unittest.TestCase):
def test_encode_datetime(self):
obj = datetime.strptime('2017-05-21 00:00:00', '%Y-%m-%d %H:%M:%S')
- self.assertEqual(json.dumps(obj, cls=utils_json.AirflowJsonEncoder), '"2017-05-21T00:00:00Z"')
+ assert json.dumps(obj, cls=utils_json.AirflowJsonEncoder) == '"2017-05-21T00:00:00Z"'
def test_encode_date(self):
- self.assertEqual(json.dumps(date(2017, 5, 21), cls=utils_json.AirflowJsonEncoder), '"2017-05-21"')
+ assert json.dumps(date(2017, 5, 21), cls=utils_json.AirflowJsonEncoder) == '"2017-05-21"'
def test_encode_numpy_int(self):
- self.assertEqual(json.dumps(np.int32(5), cls=utils_json.AirflowJsonEncoder), '5')
+ assert json.dumps(np.int32(5), cls=utils_json.AirflowJsonEncoder) == '5'
def test_encode_numpy_bool(self):
- self.assertEqual(json.dumps(np.bool_(True), cls=utils_json.AirflowJsonEncoder), 'true')
+ assert json.dumps(np.bool_(True), cls=utils_json.AirflowJsonEncoder) == 'true'
def test_encode_numpy_float(self):
- self.assertEqual(json.dumps(np.float16(3.76953125), cls=utils_json.AirflowJsonEncoder), '3.76953125')
+ assert json.dumps(np.float16(3.76953125), cls=utils_json.AirflowJsonEncoder) == '3.76953125'
def test_encode_k8s_v1pod(self):
from kubernetes.client import models as k8s
@@ -59,19 +60,14 @@ def test_encode_k8s_v1pod(self):
]
),
)
- self.assertEqual(
- json.loads(json.dumps(pod, cls=utils_json.AirflowJsonEncoder)),
- {
- "metadata": {"name": "foo", "namespace": "bar"},
- "spec": {"containers": [{"image": "bar", "name": "foo"}]},
- },
- )
+ assert json.loads(json.dumps(pod, cls=utils_json.AirflowJsonEncoder)) == {
+ "metadata": {"name": "foo", "namespace": "bar"},
+ "spec": {"containers": [{"image": "bar", "name": "foo"}]},
+ }
def test_encode_raises(self):
- self.assertRaisesRegex(
- TypeError,
- "^.*is not JSON serializable$",
- json.dumps,
- Exception,
- cls=utils_json.AirflowJsonEncoder,
- )
+ with pytest.raises(TypeError, match="^.*is not JSON serializable$"):
+ json.dumps(
+ Exception,
+ cls=utils_json.AirflowJsonEncoder,
+ )
diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py
index 3b747ee2ce4de..76115a2f3a279 100644
--- a/tests/utils/test_log_handlers.py
+++ b/tests/utils/test_log_handlers.py
@@ -19,6 +19,7 @@
import logging
import logging.config
import os
+import re
import unittest
from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG
@@ -58,9 +59,9 @@ def test_default_task_logging_setup(self):
# file task handler is used by default.
logger = logging.getLogger(TASK_LOGGER)
handlers = logger.handlers
- self.assertEqual(len(handlers), 1)
+ assert len(handlers) == 1
handler = handlers[0]
- self.assertEqual(handler.name, FILE_TASK_HANDLER)
+ assert handler.name == FILE_TASK_HANDLER
def test_file_task_handler(self):
def task_callable(ti, **kwargs):
@@ -81,33 +82,33 @@ def task_callable(ti, **kwargs):
file_handler = next(
(handler for handler in logger.handlers if handler.name == FILE_TASK_HANDLER), None
)
- self.assertIsNotNone(file_handler)
+ assert file_handler is not None
set_context(logger, ti)
- self.assertIsNotNone(file_handler.handler)
+ assert file_handler.handler is not None
# We expect set_context generates a file locally.
log_filename = file_handler.handler.baseFilename
- self.assertTrue(os.path.isfile(log_filename))
- self.assertTrue(log_filename.endswith("1.log"), log_filename)
+ assert os.path.isfile(log_filename)
+ assert log_filename.endswith("1.log"), log_filename
ti.run(ignore_ti_state=True)
file_handler.flush()
file_handler.close()
- self.assertTrue(hasattr(file_handler, 'read'))
+ assert hasattr(file_handler, 'read')
# Return value of read must be a tuple of list and list.
logs, metadatas = file_handler.read(ti)
- self.assertTrue(isinstance(logs, list))
- self.assertTrue(isinstance(metadatas, list))
- self.assertEqual(len(logs), 1)
- self.assertEqual(len(logs), len(metadatas))
- self.assertTrue(isinstance(metadatas[0], dict))
+ assert isinstance(logs, list)
+ assert isinstance(metadatas, list)
+ assert len(logs) == 1
+ assert len(logs) == len(metadatas)
+ assert isinstance(metadatas[0], dict)
target_re = r'\n\[[^\]]+\] {test_log_handlers.py:\d+} INFO - test\n'
# We should expect our log line from the callable above to appear in
# the logs we read back
- self.assertRegex(logs[0][0][-1], target_re, "Logs were " + str(logs))
+ assert re.search(target_re, logs[0][0][-1]), "Logs were " + str(logs)
# Remove the generated tmp log file.
os.remove(log_filename)
@@ -132,26 +133,26 @@ def task_callable(ti, **kwargs):
file_handler = next(
(handler for handler in logger.handlers if handler.name == FILE_TASK_HANDLER), None
)
- self.assertIsNotNone(file_handler)
+ assert file_handler is not None
set_context(logger, ti)
- self.assertIsNotNone(file_handler.handler)
+ assert file_handler.handler is not None
# We expect set_context generates a file locally.
log_filename = file_handler.handler.baseFilename
- self.assertTrue(os.path.isfile(log_filename))
- self.assertTrue(log_filename.endswith("2.log"), log_filename)
+ assert os.path.isfile(log_filename)
+ assert log_filename.endswith("2.log"), log_filename
logger.info("Test")
# Return value of read must be a tuple of list and list.
logs, metadatas = file_handler.read(ti)
- self.assertTrue(isinstance(logs, list))
+ assert isinstance(logs, list)
# Logs for running tasks should show up too.
- self.assertTrue(isinstance(logs, list))
- self.assertTrue(isinstance(metadatas, list))
- self.assertEqual(len(logs), 2)
- self.assertEqual(len(logs), len(metadatas))
- self.assertTrue(isinstance(metadatas[0], dict))
+ assert isinstance(logs, list)
+ assert isinstance(metadatas, list)
+ assert len(logs) == 2
+ assert len(logs) == len(metadatas)
+ assert isinstance(metadatas[0], dict)
# Remove the generated tmp log file.
os.remove(log_filename)
@@ -171,7 +172,7 @@ def test_python_formatting(self):
fth = FileTaskHandler('', '{dag_id}/{task_id}/{execution_date}/{try_number}.log')
rendered_filename = fth._render_filename(self.ti, 42)
- self.assertEqual(expected_filename, rendered_filename)
+ assert expected_filename == rendered_filename
def test_jinja_rendering(self):
expected_filename = (
@@ -181,4 +182,4 @@ def test_jinja_rendering(self):
fth = FileTaskHandler('', '{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts }}/{{ try_number }}.log')
rendered_filename = fth._render_filename(self.ti, 42)
- self.assertEqual(expected_filename, rendered_filename)
+ assert expected_filename == rendered_filename
diff --git a/tests/utils/test_logging_mixin.py b/tests/utils/test_logging_mixin.py
index 583eefd4bccbf..7d19b9d5632a0 100644
--- a/tests/utils/test_logging_mixin.py
+++ b/tests/utils/test_logging_mixin.py
@@ -62,12 +62,12 @@ def test_write(self):
msg = "test_message"
log.write(msg)
- self.assertEqual(log._buffer, msg)
+ assert log._buffer == msg
log.write(" \n")
logger.log.assert_called_once_with(1, msg)
- self.assertEqual(log._buffer, "")
+ assert log._buffer == ""
def test_flush(self):
logger = mock.MagicMock()
@@ -78,30 +78,30 @@ def test_flush(self):
msg = "test_message"
log.write(msg)
- self.assertEqual(log._buffer, msg)
+ assert log._buffer == msg
log.flush()
logger.log.assert_called_once_with(1, msg)
- self.assertEqual(log._buffer, "")
+ assert log._buffer == ""
def test_isatty(self):
logger = mock.MagicMock()
logger.log = mock.MagicMock()
log = StreamLogWriter(logger, 1)
- self.assertFalse(log.isatty())
+ assert not log.isatty()
def test_encoding(self):
logger = mock.MagicMock()
logger.log = mock.MagicMock()
log = StreamLogWriter(logger, 1)
- self.assertIsNone(log.encoding)
+ assert log.encoding is None
def test_iobase_compatibility(self):
log = StreamLogWriter(None, 1)
- self.assertFalse(log.closed)
+ assert not log.closed
# has no specific effect
log.close()
diff --git a/tests/utils/test_module_loading.py b/tests/utils/test_module_loading.py
index 94644e023ce90..51daee67b22f4 100644
--- a/tests/utils/test_module_loading.py
+++ b/tests/utils/test_module_loading.py
@@ -18,17 +18,19 @@
import unittest
+import pytest
+
from airflow.utils.module_loading import import_string
class TestModuleImport(unittest.TestCase):
def test_import_string(self):
cls = import_string('airflow.utils.module_loading.import_string')
- self.assertEqual(cls, import_string)
+ assert cls == import_string # pylint: disable=comparison-with-callable
# Test exceptions raised
- with self.assertRaises(ImportError):
+ with pytest.raises(ImportError):
import_string('no_dots_in_path')
msg = 'Module "airflow.utils" does not define a "nonexistent" attribute'
- with self.assertRaisesRegex(ImportError, msg):
+ with pytest.raises(ImportError, match=msg):
import_string('airflow.utils.nonexistent')
diff --git a/tests/utils/test_net.py b/tests/utils/test_net.py
index 2f06371018a7f..3ac2e344aff2a 100644
--- a/tests/utils/test_net.py
+++ b/tests/utils/test_net.py
@@ -19,6 +19,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.exceptions import AirflowConfigException
from airflow.utils import net
from tests.test_utils.config import conf_vars
@@ -32,22 +34,22 @@ class TestGetHostname(unittest.TestCase):
@mock.patch('socket.getfqdn', return_value='first')
@conf_vars({('core', 'hostname_callable'): None})
def test_get_hostname_unset(self, mock_getfqdn):
- self.assertEqual('first', net.get_hostname())
+ assert 'first' == net.get_hostname()
@conf_vars({('core', 'hostname_callable'): 'tests.utils.test_net.get_hostname'})
def test_get_hostname_set(self):
- self.assertEqual('awesomehostname', net.get_hostname())
+ assert 'awesomehostname' == net.get_hostname()
@conf_vars({('core', 'hostname_callable'): 'tests.utils.test_net'})
def test_get_hostname_set_incorrect(self):
- with self.assertRaises(TypeError):
+ with pytest.raises(TypeError):
net.get_hostname()
@conf_vars({('core', 'hostname_callable'): 'tests.utils.test_net.missing_func'})
def test_get_hostname_set_missing(self):
- with self.assertRaisesRegex(
+ with pytest.raises(
AirflowConfigException,
- re.escape(
+ match=re.escape(
'The object could not be loaded. Please check "hostname_callable" key in "core" section. '
'Current value: "tests.utils.test_net.missing_func"'
),
diff --git a/tests/utils/test_operator_helpers.py b/tests/utils/test_operator_helpers.py
index 89cdf8abd5e9f..46d8b8b27e022 100644
--- a/tests/utils/test_operator_helpers.py
+++ b/tests/utils/test_operator_helpers.py
@@ -50,32 +50,26 @@ def setUp(self):
}
def test_context_to_airflow_vars_empty_context(self):
- self.assertDictEqual(operator_helpers.context_to_airflow_vars({}), {})
+ assert operator_helpers.context_to_airflow_vars({}) == {}
def test_context_to_airflow_vars_all_context(self):
- self.assertDictEqual(
- operator_helpers.context_to_airflow_vars(self.context),
- {
- 'airflow.ctx.dag_id': self.dag_id,
- 'airflow.ctx.execution_date': self.execution_date,
- 'airflow.ctx.task_id': self.task_id,
- 'airflow.ctx.dag_run_id': self.dag_run_id,
- 'airflow.ctx.dag_owner': 'owner1,owner2',
- 'airflow.ctx.dag_email': 'email1@test.com',
- },
- )
-
- self.assertDictEqual(
- operator_helpers.context_to_airflow_vars(self.context, in_env_var_format=True),
- {
- 'AIRFLOW_CTX_DAG_ID': self.dag_id,
- 'AIRFLOW_CTX_EXECUTION_DATE': self.execution_date,
- 'AIRFLOW_CTX_TASK_ID': self.task_id,
- 'AIRFLOW_CTX_DAG_RUN_ID': self.dag_run_id,
- 'AIRFLOW_CTX_DAG_OWNER': 'owner1,owner2',
- 'AIRFLOW_CTX_DAG_EMAIL': 'email1@test.com',
- },
- )
+ assert operator_helpers.context_to_airflow_vars(self.context) == {
+ 'airflow.ctx.dag_id': self.dag_id,
+ 'airflow.ctx.execution_date': self.execution_date,
+ 'airflow.ctx.task_id': self.task_id,
+ 'airflow.ctx.dag_run_id': self.dag_run_id,
+ 'airflow.ctx.dag_owner': 'owner1,owner2',
+ 'airflow.ctx.dag_email': 'email1@test.com',
+ }
+
+ assert operator_helpers.context_to_airflow_vars(self.context, in_env_var_format=True) == {
+ 'AIRFLOW_CTX_DAG_ID': self.dag_id,
+ 'AIRFLOW_CTX_EXECUTION_DATE': self.execution_date,
+ 'AIRFLOW_CTX_TASK_ID': self.task_id,
+ 'AIRFLOW_CTX_DAG_RUN_ID': self.dag_run_id,
+ 'AIRFLOW_CTX_DAG_OWNER': 'owner1,owner2',
+ 'AIRFLOW_CTX_DAG_EMAIL': 'email1@test.com',
+ }
def callable1(ds_nodash):
diff --git a/tests/utils/test_process_utils.py b/tests/utils/test_process_utils.py
index dc1785486e8fc..2c14ae42a47a5 100644
--- a/tests/utils/test_process_utils.py
+++ b/tests/utils/test_process_utils.py
@@ -80,14 +80,14 @@ def test_reap_process_group(self):
parent = multiprocessing.Process(target=TestReapProcessGroup._parent_of_ignores_sigterm, args=args)
try:
parent.start()
- self.assertTrue(parent_setup_done.acquire(timeout=5.0))
- self.assertTrue(psutil.pid_exists(parent_pid.value))
- self.assertTrue(psutil.pid_exists(child_pid.value))
+ assert parent_setup_done.acquire(timeout=5.0)
+ assert psutil.pid_exists(parent_pid.value)
+ assert psutil.pid_exists(child_pid.value)
process_utils.reap_process_group(parent_pid.value, logging.getLogger(), timeout=1)
- self.assertFalse(psutil.pid_exists(parent_pid.value))
- self.assertFalse(psutil.pid_exists(child_pid.value))
+ assert not psutil.pid_exists(parent_pid.value)
+ assert not psutil.pid_exists(child_pid.value)
finally:
try:
os.kill(parent_pid.value, signal.SIGKILL) # terminate doesnt work here
@@ -103,10 +103,10 @@ def test_should_print_all_messages1(self):
msgs = [record.getMessage() for record in logs.records]
- self.assertEqual(["Executing cmd: bash -c 'echo CAT; echo KITTY;'", 'Output:', 'CAT', 'KITTY'], msgs)
+ assert ["Executing cmd: bash -c 'echo CAT; echo KITTY;'", 'Output:', 'CAT', 'KITTY'] == msgs
def test_should_raise_exception(self):
- with self.assertRaises(CalledProcessError):
+ with pytest.raises(CalledProcessError):
process_utils.execute_in_subprocess(["bash", "-c", "exit 1"])
@@ -129,12 +129,12 @@ def test_should_kill_process(self):
sleep(0)
num_process = subprocess.check_output(["ps", "-ax", "-o", "pid="]).decode().count("\n")
- self.assertEqual(before_num_process + 1, num_process)
+ assert before_num_process + 1 == num_process
process_utils.kill_child_processes_by_pids([process.pid])
num_process = subprocess.check_output(["ps", "-ax", "-o", "pid="]).decode().count("\n")
- self.assertEqual(before_num_process, num_process)
+ assert before_num_process == num_process
@pytest.mark.quarantined
def test_should_force_kill_process(self):
@@ -145,14 +145,14 @@ def test_should_force_kill_process(self):
sleep(0)
num_process = subprocess.check_output(["ps", "-ax", "-o", "pid="]).decode().count("\n")
- self.assertEqual(before_num_process + 1, num_process)
+ assert before_num_process + 1 == num_process
with self.assertLogs(process_utils.log) as cm:
process_utils.kill_child_processes_by_pids([process.pid], timeout=0)
- self.assertTrue(any("Killing child PID" in line for line in cm.output))
+ assert any("Killing child PID" in line for line in cm.output)
num_process = subprocess.check_output(["ps", "-ax", "-o", "pid="]).decode().count("\n")
- self.assertEqual(before_num_process, num_process)
+ assert before_num_process == num_process
class TestPatchEnviron(unittest.TestCase):
@@ -160,31 +160,31 @@ def test_should_update_variable_and_restore_state_when_exit(self):
with mock.patch.dict("os.environ", {"TEST_NOT_EXISTS": "BEFORE", "TEST_EXISTS": "BEFORE"}):
del os.environ["TEST_NOT_EXISTS"]
- self.assertEqual("BEFORE", os.environ["TEST_EXISTS"])
- self.assertNotIn("TEST_NOT_EXISTS", os.environ)
+ assert "BEFORE" == os.environ["TEST_EXISTS"]
+ assert "TEST_NOT_EXISTS" not in os.environ
with process_utils.patch_environ({"TEST_NOT_EXISTS": "AFTER", "TEST_EXISTS": "AFTER"}):
- self.assertEqual("AFTER", os.environ["TEST_NOT_EXISTS"])
- self.assertEqual("AFTER", os.environ["TEST_EXISTS"])
+ assert "AFTER" == os.environ["TEST_NOT_EXISTS"]
+ assert "AFTER" == os.environ["TEST_EXISTS"]
- self.assertEqual("BEFORE", os.environ["TEST_EXISTS"])
- self.assertNotIn("TEST_NOT_EXISTS", os.environ)
+ assert "BEFORE" == os.environ["TEST_EXISTS"]
+ assert "TEST_NOT_EXISTS" not in os.environ
def test_should_restore_state_when_exception(self):
with mock.patch.dict("os.environ", {"TEST_NOT_EXISTS": "BEFORE", "TEST_EXISTS": "BEFORE"}):
del os.environ["TEST_NOT_EXISTS"]
- self.assertEqual("BEFORE", os.environ["TEST_EXISTS"])
- self.assertNotIn("TEST_NOT_EXISTS", os.environ)
+ assert "BEFORE" == os.environ["TEST_EXISTS"]
+ assert "TEST_NOT_EXISTS" not in os.environ
with suppress(AirflowException):
with process_utils.patch_environ({"TEST_NOT_EXISTS": "AFTER", "TEST_EXISTS": "AFTER"}):
- self.assertEqual("AFTER", os.environ["TEST_NOT_EXISTS"])
- self.assertEqual("AFTER", os.environ["TEST_EXISTS"])
+ assert "AFTER" == os.environ["TEST_NOT_EXISTS"]
+ assert "AFTER" == os.environ["TEST_EXISTS"]
raise AirflowException("Unknown exception")
- self.assertEqual("BEFORE", os.environ["TEST_EXISTS"])
- self.assertNotIn("TEST_NOT_EXISTS", os.environ)
+ assert "BEFORE" == os.environ["TEST_EXISTS"]
+ assert "TEST_NOT_EXISTS" not in os.environ
class TestCheckIfPidfileProcessIsRunning(unittest.TestCase):
@@ -193,7 +193,7 @@ def test_ok_if_no_file(self):
def test_remove_if_no_process(self):
# Assert file is deleted
- with self.assertRaises(FileNotFoundError):
+ with pytest.raises(FileNotFoundError):
with NamedTemporaryFile('+w') as f:
f.write('19191919191919191991')
f.flush()
@@ -204,5 +204,5 @@ def test_raise_error_if_process_is_running(self):
with NamedTemporaryFile('+w') as f:
f.write(str(pid))
f.flush()
- with self.assertRaisesRegex(AirflowException, "is already running under PID"):
+ with pytest.raises(AirflowException, match="is already running under PID"):
check_if_pidfile_process_is_running(f.name, process_name="test")
diff --git a/tests/utils/test_python_virtualenv.py b/tests/utils/test_python_virtualenv.py
index 59ad504b84e1f..1e4c10fe02236 100644
--- a/tests/utils/test_python_virtualenv.py
+++ b/tests/utils/test_python_virtualenv.py
@@ -28,7 +28,7 @@ def test_should_create_virtualenv(self, mock_execute_in_subprocess):
python_bin = prepare_virtualenv(
venv_directory="/VENV", python_bin="pythonVER", system_site_packages=False, requirements=[]
)
- self.assertEqual("/VENV/bin/python", python_bin)
+ assert "/VENV/bin/python" == python_bin
mock_execute_in_subprocess.assert_called_once_with(['virtualenv', '/VENV', '--python=pythonVER'])
@mock.patch('airflow.utils.python_virtualenv.execute_in_subprocess')
@@ -36,7 +36,7 @@ def test_should_create_virtualenv_with_system_packages(self, mock_execute_in_sub
python_bin = prepare_virtualenv(
venv_directory="/VENV", python_bin="pythonVER", system_site_packages=True, requirements=[]
)
- self.assertEqual("/VENV/bin/python", python_bin)
+ assert "/VENV/bin/python" == python_bin
mock_execute_in_subprocess.assert_called_once_with(
['virtualenv', '/VENV', '--system-site-packages', '--python=pythonVER']
)
@@ -49,7 +49,7 @@ def test_should_create_virtualenv_with_extra_packages(self, mock_execute_in_subp
system_site_packages=False,
requirements=['apache-beam[gcp]'],
)
- self.assertEqual("/VENV/bin/python", python_bin)
+ assert "/VENV/bin/python" == python_bin
mock_execute_in_subprocess.assert_any_call(['virtualenv', '/VENV', '--python=pythonVER'])
diff --git a/tests/utils/test_serve_logs.py b/tests/utils/test_serve_logs.py
index 7f687daad5769..edb4d0991b3e6 100644
--- a/tests/utils/test_serve_logs.py
+++ b/tests/utils/test_serve_logs.py
@@ -42,5 +42,5 @@ def test_should_serve_file(self):
sub_proc.start()
sleep(1)
log_url = f"http://localhost:{log_port}/log/{basename(f.name)}"
- self.assertEqual(LOG_DATA, requests.get(log_url).content.decode())
+ assert LOG_DATA == requests.get(log_url).content.decode()
sub_proc.terminate()
diff --git a/tests/utils/test_sqlalchemy.py b/tests/utils/test_sqlalchemy.py
index 9d83cd77ffc7b..816f147bba7e8 100644
--- a/tests/utils/test_sqlalchemy.py
+++ b/tests/utils/test_sqlalchemy.py
@@ -67,14 +67,14 @@ def test_utc_transformations(self):
session=self.session,
)
- self.assertEqual(execution_date, run.execution_date)
- self.assertEqual(start_date, run.start_date)
+ assert execution_date == run.execution_date
+ assert start_date == run.start_date
- self.assertEqual(execution_date.utcoffset().total_seconds(), 0.0)
- self.assertEqual(start_date.utcoffset().total_seconds(), 0.0)
+ assert execution_date.utcoffset().total_seconds() == 0.0
+ assert start_date.utcoffset().total_seconds() == 0.0
- self.assertEqual(iso_date, run.run_id)
- self.assertEqual(run.start_date.isoformat(), run.run_id)
+ assert iso_date == run.run_id
+ assert run.start_date.isoformat() == run.run_id
dag.clear()
@@ -89,7 +89,7 @@ def test_process_bind_param_naive(self):
dag = DAG(dag_id=dag_id, start_date=start_date)
dag.clear()
- with self.assertRaises((ValueError, StatementError)):
+ with pytest.raises((ValueError, StatementError)):
dag.create_dagrun(
run_id=start_date.isoformat,
state=State.NONE,
@@ -127,7 +127,7 @@ def test_skip_locked(self, dialect, supports_for_update_of, expected_return_valu
session = mock.Mock()
session.bind.dialect.name = dialect
session.bind.dialect.supports_for_update_of = supports_for_update_of
- self.assertEqual(skip_locked(session=session), expected_return_value)
+ assert skip_locked(session=session) == expected_return_value
@parameterized.expand(
[
@@ -159,7 +159,7 @@ def test_nowait(self, dialect, supports_for_update_of, expected_return_value):
session = mock.Mock()
session.bind.dialect.name = dialect
session.bind.dialect.supports_for_update_of = supports_for_update_of
- self.assertEqual(nowait(session=session), expected_return_value)
+ assert nowait(session=session) == expected_return_value
def test_prohibit_commit(self):
with prohibit_commit(self.session) as guard:
diff --git a/tests/utils/test_task_handler_with_custom_formatter.py b/tests/utils/test_task_handler_with_custom_formatter.py
index 458d5c7a1861b..27724c8c38d04 100644
--- a/tests/utils/test_task_handler_with_custom_formatter.py
+++ b/tests/utils/test_task_handler_with_custom_formatter.py
@@ -58,9 +58,9 @@ def test_formatter(self):
logger = ti.log
ti.log.disabled = False
handler = next((handler for handler in logger.handlers if handler.name == TASK_HANDLER), None)
- self.assertIsNotNone(handler)
+ assert handler is not None
# setting the expected value of the formatter
expected_formatter_value = "test_dag-test_task:" + handler.formatter._fmt
set_context(logger, ti)
- self.assertEqual(expected_formatter_value, handler.formatter._fmt)
+ assert expected_formatter_value == handler.formatter._fmt
diff --git a/tests/utils/test_timezone.py b/tests/utils/test_timezone.py
index d7249bde1cdd4..c8dd1208280ad 100644
--- a/tests/utils/test_timezone.py
+++ b/tests/utils/test_timezone.py
@@ -20,6 +20,7 @@
import unittest
import pendulum
+import pytest
from airflow.utils import timezone
@@ -31,44 +32,41 @@
class TestTimezone(unittest.TestCase):
def test_is_aware(self):
- self.assertTrue(timezone.is_localized(datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT)))
- self.assertFalse(timezone.is_localized(datetime.datetime(2011, 9, 1, 13, 20, 30)))
+ assert timezone.is_localized(datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT))
+ assert not timezone.is_localized(datetime.datetime(2011, 9, 1, 13, 20, 30))
def test_is_naive(self):
- self.assertFalse(timezone.is_naive(datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT)))
- self.assertTrue(timezone.is_naive(datetime.datetime(2011, 9, 1, 13, 20, 30)))
+ assert not timezone.is_naive(datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT))
+ assert timezone.is_naive(datetime.datetime(2011, 9, 1, 13, 20, 30))
def test_utcnow(self):
now = timezone.utcnow()
- self.assertTrue(timezone.is_localized(now))
- self.assertEqual(now.replace(tzinfo=None), now.astimezone(UTC).replace(tzinfo=None))
+ assert timezone.is_localized(now)
+ assert now.replace(tzinfo=None) == now.astimezone(UTC).replace(tzinfo=None)
def test_convert_to_utc(self):
naive = datetime.datetime(2011, 9, 1, 13, 20, 30)
utc = datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=UTC)
- self.assertEqual(utc, timezone.convert_to_utc(naive))
+ assert utc == timezone.convert_to_utc(naive)
eat = datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT)
utc = datetime.datetime(2011, 9, 1, 10, 20, 30, tzinfo=UTC)
- self.assertEqual(utc, timezone.convert_to_utc(eat))
+ assert utc == timezone.convert_to_utc(eat)
def test_make_naive(self):
- self.assertEqual(
- timezone.make_naive(datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT), EAT),
- datetime.datetime(2011, 9, 1, 13, 20, 30),
- )
- self.assertEqual(
- timezone.make_naive(datetime.datetime(2011, 9, 1, 17, 20, 30, tzinfo=ICT), EAT),
- datetime.datetime(2011, 9, 1, 13, 20, 30),
- )
+ assert timezone.make_naive(
+ datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT), EAT
+ ) == datetime.datetime(2011, 9, 1, 13, 20, 30)
+ assert timezone.make_naive(
+ datetime.datetime(2011, 9, 1, 17, 20, 30, tzinfo=ICT), EAT
+ ) == datetime.datetime(2011, 9, 1, 13, 20, 30)
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
timezone.make_naive(datetime.datetime(2011, 9, 1, 13, 20, 30), EAT)
def test_make_aware(self):
- self.assertEqual(
- timezone.make_aware(datetime.datetime(2011, 9, 1, 13, 20, 30), EAT),
- datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT),
+ assert timezone.make_aware(datetime.datetime(2011, 9, 1, 13, 20, 30), EAT) == datetime.datetime(
+ 2011, 9, 1, 13, 20, 30, tzinfo=EAT
)
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
timezone.make_aware(datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT), EAT)
diff --git a/tests/utils/test_trigger_rule.py b/tests/utils/test_trigger_rule.py
index afbe0fcf48a37..be7a903afb32b 100644
--- a/tests/utils/test_trigger_rule.py
+++ b/tests/utils/test_trigger_rule.py
@@ -23,13 +23,13 @@
class TestTriggerRule(unittest.TestCase):
def test_valid_trigger_rules(self):
- self.assertTrue(TriggerRule.is_valid(TriggerRule.ALL_SUCCESS))
- self.assertTrue(TriggerRule.is_valid(TriggerRule.ALL_FAILED))
- self.assertTrue(TriggerRule.is_valid(TriggerRule.ALL_DONE))
- self.assertTrue(TriggerRule.is_valid(TriggerRule.ONE_SUCCESS))
- self.assertTrue(TriggerRule.is_valid(TriggerRule.ONE_FAILED))
- self.assertTrue(TriggerRule.is_valid(TriggerRule.NONE_FAILED))
- self.assertTrue(TriggerRule.is_valid(TriggerRule.NONE_FAILED_OR_SKIPPED))
- self.assertTrue(TriggerRule.is_valid(TriggerRule.NONE_SKIPPED))
- self.assertTrue(TriggerRule.is_valid(TriggerRule.DUMMY))
- self.assertEqual(len(TriggerRule.all_triggers()), 9)
+ assert TriggerRule.is_valid(TriggerRule.ALL_SUCCESS)
+ assert TriggerRule.is_valid(TriggerRule.ALL_FAILED)
+ assert TriggerRule.is_valid(TriggerRule.ALL_DONE)
+ assert TriggerRule.is_valid(TriggerRule.ONE_SUCCESS)
+ assert TriggerRule.is_valid(TriggerRule.ONE_FAILED)
+ assert TriggerRule.is_valid(TriggerRule.NONE_FAILED)
+ assert TriggerRule.is_valid(TriggerRule.NONE_FAILED_OR_SKIPPED)
+ assert TriggerRule.is_valid(TriggerRule.NONE_SKIPPED)
+ assert TriggerRule.is_valid(TriggerRule.DUMMY)
+ assert len(TriggerRule.all_triggers()) == 9
diff --git a/tests/utils/test_weekday.py b/tests/utils/test_weekday.py
index e679ed295e860..5d0dcb03a869d 100644
--- a/tests/utils/test_weekday.py
+++ b/tests/utils/test_weekday.py
@@ -23,18 +23,18 @@
class TestWeekDay(unittest.TestCase):
def test_weekday_enum_length(self):
- self.assertEqual(len(WeekDay), 7)
+ assert len(WeekDay) == 7
def test_weekday_name_value(self):
weekdays = "MONDAY TUESDAY WEDNESDAY THURSDAY FRIDAY SATURDAY SUNDAY"
weekdays = weekdays.split()
for i, weekday in enumerate(weekdays, start=1):
weekday_enum = WeekDay(i)
- self.assertEqual(weekday_enum, i)
- self.assertEqual(int(weekday_enum), i)
- self.assertEqual(weekday_enum.name, weekday)
- self.assertTrue(weekday_enum in WeekDay)
- self.assertTrue(0 < weekday_enum < 8)
- self.assertIsInstance(weekday_enum, WeekDay)
- self.assertIsInstance(weekday_enum, int)
- self.assertIsInstance(weekday_enum, Enum)
+ assert weekday_enum == i
+ assert int(weekday_enum) == i
+ assert weekday_enum.name == weekday
+ assert weekday_enum in WeekDay
+ assert 0 < weekday_enum < 8
+ assert isinstance(weekday_enum, WeekDay)
+ assert isinstance(weekday_enum, int)
+ assert isinstance(weekday_enum, Enum)
diff --git a/tests/utils/test_weight_rule.py b/tests/utils/test_weight_rule.py
index 862e1b426f7fa..cad142de47d19 100644
--- a/tests/utils/test_weight_rule.py
+++ b/tests/utils/test_weight_rule.py
@@ -23,7 +23,7 @@
class TestWeightRule(unittest.TestCase):
def test_valid_weight_rules(self):
- self.assertTrue(WeightRule.is_valid(WeightRule.DOWNSTREAM))
- self.assertTrue(WeightRule.is_valid(WeightRule.UPSTREAM))
- self.assertTrue(WeightRule.is_valid(WeightRule.ABSOLUTE))
- self.assertEqual(len(WeightRule.all_weight_rules()), 3)
+ assert WeightRule.is_valid(WeightRule.DOWNSTREAM)
+ assert WeightRule.is_valid(WeightRule.UPSTREAM)
+ assert WeightRule.is_valid(WeightRule.ABSOLUTE)
+ assert len(WeightRule.all_weight_rules()) == 3
diff --git a/tests/www/api/experimental/test_dag_runs_endpoint.py b/tests/www/api/experimental/test_dag_runs_endpoint.py
index b094c017fb4df..24190a535ceff 100644
--- a/tests/www/api/experimental/test_dag_runs_endpoint.py
+++ b/tests/www/api/experimental/test_dag_runs_endpoint.py
@@ -63,13 +63,13 @@ def test_get_dag_runs_success(self):
dag_run = trigger_dag(dag_id=dag_id, run_id='test_get_dag_runs_success')
response = self.app.get(url_template.format(dag_id))
- self.assertEqual(200, response.status_code)
+ assert 200 == response.status_code
data = json.loads(response.data.decode('utf-8'))
- self.assertIsInstance(data, list)
- self.assertEqual(len(data), 1)
- self.assertEqual(data[0]['dag_id'], dag_id)
- self.assertEqual(data[0]['id'], dag_run.id)
+ assert isinstance(data, list)
+ assert len(data) == 1
+ assert data[0]['dag_id'] == dag_id
+ assert data[0]['id'] == dag_run.id
def test_get_dag_runs_success_with_state_parameter(self):
url_template = '/api/experimental/dags/{}/dag_runs?state=running'
@@ -78,13 +78,13 @@ def test_get_dag_runs_success_with_state_parameter(self):
dag_run = trigger_dag(dag_id=dag_id, run_id='test_get_dag_runs_success')
response = self.app.get(url_template.format(dag_id))
- self.assertEqual(200, response.status_code)
+ assert 200 == response.status_code
data = json.loads(response.data.decode('utf-8'))
- self.assertIsInstance(data, list)
- self.assertEqual(len(data), 1)
- self.assertEqual(data[0]['dag_id'], dag_id)
- self.assertEqual(data[0]['id'], dag_run.id)
+ assert isinstance(data, list)
+ assert len(data) == 1
+ assert data[0]['dag_id'] == dag_id
+ assert data[0]['id'] == dag_run.id
def test_get_dag_runs_success_with_capital_state_parameter(self):
url_template = '/api/experimental/dags/{}/dag_runs?state=RUNNING'
@@ -93,13 +93,13 @@ def test_get_dag_runs_success_with_capital_state_parameter(self):
dag_run = trigger_dag(dag_id=dag_id, run_id='test_get_dag_runs_success')
response = self.app.get(url_template.format(dag_id))
- self.assertEqual(200, response.status_code)
+ assert 200 == response.status_code
data = json.loads(response.data.decode('utf-8'))
- self.assertIsInstance(data, list)
- self.assertEqual(len(data), 1)
- self.assertEqual(data[0]['dag_id'], dag_id)
- self.assertEqual(data[0]['id'], dag_run.id)
+ assert isinstance(data, list)
+ assert len(data) == 1
+ assert data[0]['dag_id'] == dag_id
+ assert data[0]['id'] == dag_run.id
def test_get_dag_runs_success_with_state_no_result(self):
url_template = '/api/experimental/dags/{}/dag_runs?state=dummy'
@@ -108,29 +108,29 @@ def test_get_dag_runs_success_with_state_no_result(self):
trigger_dag(dag_id=dag_id, run_id='test_get_dag_runs_success')
response = self.app.get(url_template.format(dag_id))
- self.assertEqual(200, response.status_code)
+ assert 200 == response.status_code
data = json.loads(response.data.decode('utf-8'))
- self.assertIsInstance(data, list)
- self.assertEqual(len(data), 0)
+ assert isinstance(data, list)
+ assert len(data) == 0
def test_get_dag_runs_invalid_dag_id(self):
url_template = '/api/experimental/dags/{}/dag_runs'
dag_id = 'DUMMY_DAG'
response = self.app.get(url_template.format(dag_id))
- self.assertEqual(400, response.status_code)
+ assert 400 == response.status_code
data = json.loads(response.data.decode('utf-8'))
- self.assertNotIsInstance(data, list)
+ assert not isinstance(data, list)
def test_get_dag_runs_no_runs(self):
url_template = '/api/experimental/dags/{}/dag_runs'
dag_id = 'example_bash_operator'
response = self.app.get(url_template.format(dag_id))
- self.assertEqual(200, response.status_code)
+ assert 200 == response.status_code
data = json.loads(response.data.decode('utf-8'))
- self.assertIsInstance(data, list)
- self.assertEqual(len(data), 0)
+ assert isinstance(data, list)
+ assert len(data) == 0
diff --git a/tests/www/api/experimental/test_endpoints.py b/tests/www/api/experimental/test_endpoints.py
index be569985e9b26..5981eacccfc0e 100644
--- a/tests/www/api/experimental/test_endpoints.py
+++ b/tests/www/api/experimental/test_endpoints.py
@@ -17,6 +17,7 @@
# under the License.
import json
import os
+import re
import unittest
from datetime import timedelta
from unittest import mock
@@ -52,10 +53,10 @@ def setUp(self):
self.session = Session
def assert_deprecated(self, resp):
- self.assertEqual('true', resp.headers['Deprecation'])
- self.assertRegex(
- resp.headers['Link'],
+ assert 'true' == resp.headers['Deprecation']
+ assert re.search(
r'\<.+/stable-rest-api/migration.html\>; ' 'rel="deprecation"; type="text/html"',
+ resp.headers['Link'],
)
@@ -88,7 +89,7 @@ def test_info(self):
resp_raw = self.client.get(url)
resp = json.loads(resp_raw.data.decode('utf-8'))
- self.assertEqual(version, resp['version'])
+ assert version == resp['version']
self.assert_deprecated(resp_raw)
def test_task_info(self):
@@ -97,28 +98,28 @@ def test_task_info(self):
response = self.client.get(url_template.format('example_bash_operator', 'runme_0'))
self.assert_deprecated(response)
- self.assertIn('"email"', response.data.decode('utf-8'))
- self.assertNotIn('error', response.data.decode('utf-8'))
- self.assertEqual(200, response.status_code)
+ assert '"email"' in response.data.decode('utf-8')
+ assert 'error' not in response.data.decode('utf-8')
+ assert 200 == response.status_code
response = self.client.get(url_template.format('example_bash_operator', 'DNE'))
- self.assertIn('error', response.data.decode('utf-8'))
- self.assertEqual(404, response.status_code)
+ assert 'error' in response.data.decode('utf-8')
+ assert 404 == response.status_code
response = self.client.get(url_template.format('DNE', 'DNE'))
- self.assertIn('error', response.data.decode('utf-8'))
- self.assertEqual(404, response.status_code)
+ assert 'error' in response.data.decode('utf-8')
+ assert 404 == response.status_code
def test_get_dag_code(self):
url_template = '/api/experimental/dags/{}/code'
response = self.client.get(url_template.format('example_bash_operator'))
self.assert_deprecated(response)
- self.assertIn('BashOperator(', response.data.decode('utf-8'))
- self.assertEqual(200, response.status_code)
+ assert 'BashOperator(' in response.data.decode('utf-8')
+ assert 200 == response.status_code
response = self.client.get(url_template.format('xyz'))
- self.assertEqual(404, response.status_code)
+ assert 404 == response.status_code
def test_dag_paused(self):
pause_url_template = '/api/experimental/dags/{}/paused/{}'
@@ -127,22 +128,22 @@ def test_dag_paused(self):
response = self.client.get(pause_url_template.format('example_bash_operator', 'true'))
self.assert_deprecated(response)
- self.assertIn('ok', response.data.decode('utf-8'))
- self.assertEqual(200, response.status_code)
+ assert 'ok' in response.data.decode('utf-8')
+ assert 200 == response.status_code
paused_response = self.client.get(paused_url)
- self.assertEqual(200, paused_response.status_code)
- self.assertEqual({"is_paused": True}, paused_response.json)
+ assert 200 == paused_response.status_code
+ assert {"is_paused": True} == paused_response.json
response = self.client.get(pause_url_template.format('example_bash_operator', 'false'))
- self.assertIn('ok', response.data.decode('utf-8'))
- self.assertEqual(200, response.status_code)
+ assert 'ok' in response.data.decode('utf-8')
+ assert 200 == response.status_code
paused_response = self.client.get(paused_url)
- self.assertEqual(200, paused_response.status_code)
- self.assertEqual({"is_paused": False}, paused_response.json)
+ assert 200 == paused_response.status_code
+ assert {"is_paused": False} == paused_response.json
def test_trigger_dag(self):
url_template = '/api/experimental/dags/{}/dag_runs'
@@ -154,9 +155,9 @@ def test_trigger_dag(self):
)
self.assert_deprecated(response)
- self.assertEqual(200, response.status_code)
+ assert 200 == response.status_code
response_execution_date = parse_datetime(json.loads(response.data.decode('utf-8'))['execution_date'])
- self.assertEqual(0, response_execution_date.microsecond)
+ assert 0 == response_execution_date.microsecond
# Check execution_date is correct
response = json.loads(response.data.decode('utf-8'))
@@ -164,14 +165,14 @@ def test_trigger_dag(self):
dag = dagbag.get_dag('example_bash_operator')
dag_run = dag.get_dagrun(response_execution_date)
dag_run_id = dag_run.run_id
- self.assertEqual(run_id, dag_run_id)
- self.assertEqual(dag_run_id, response['run_id'])
+ assert run_id == dag_run_id
+ assert dag_run_id == response['run_id']
# Test error for nonexistent dag
response = self.client.post(
url_template.format('does_not_exist_dag'), data=json.dumps({}), content_type="application/json"
)
- self.assertEqual(404, response.status_code)
+ assert 404 == response.status_code
def test_trigger_dag_for_date(self):
url_template = '/api/experimental/dags/{}/dag_runs'
@@ -186,13 +187,13 @@ def test_trigger_dag_for_date(self):
content_type="application/json",
)
self.assert_deprecated(response)
- self.assertEqual(200, response.status_code)
- self.assertEqual(datetime_string, json.loads(response.data.decode('utf-8'))['execution_date'])
+ assert 200 == response.status_code
+ assert datetime_string == json.loads(response.data.decode('utf-8'))['execution_date']
dagbag = DagBag()
dag = dagbag.get_dag(dag_id)
dag_run = dag.get_dagrun(execution_date)
- self.assertTrue(dag_run, f'Dag Run not found for execution date {execution_date}')
+ assert dag_run, f'Dag Run not found for execution date {execution_date}'
# Test correct execution with execution date and microseconds replaced
response = self.client.post(
@@ -200,14 +201,14 @@ def test_trigger_dag_for_date(self):
data=json.dumps({'execution_date': datetime_string, 'replace_microseconds': 'true'}),
content_type="application/json",
)
- self.assertEqual(200, response.status_code)
+ assert 200 == response.status_code
response_execution_date = parse_datetime(json.loads(response.data.decode('utf-8'))['execution_date'])
- self.assertEqual(0, response_execution_date.microsecond)
+ assert 0 == response_execution_date.microsecond
dagbag = DagBag()
dag = dagbag.get_dag(dag_id)
dag_run = dag.get_dagrun(response_execution_date)
- self.assertTrue(dag_run, f'Dag Run not found for execution date {execution_date}')
+ assert dag_run, f'Dag Run not found for execution date {execution_date}'
# Test error for nonexistent dag
response = self.client.post(
@@ -215,7 +216,7 @@ def test_trigger_dag_for_date(self):
data=json.dumps({'execution_date': datetime_string}),
content_type="application/json",
)
- self.assertEqual(404, response.status_code)
+ assert 404 == response.status_code
# Test error for bad datetime format
response = self.client.post(
@@ -223,7 +224,7 @@ def test_trigger_dag_for_date(self):
data=json.dumps({'execution_date': 'not_a_datetime'}),
content_type="application/json",
)
- self.assertEqual(400, response.status_code)
+ assert 400 == response.status_code
def test_task_instance_info(self):
url_template = '/api/experimental/dags/{}/dag_runs/{}/tasks/{}'
@@ -239,31 +240,31 @@ def test_task_instance_info(self):
# Test Correct execution
response = self.client.get(url_template.format(dag_id, datetime_string, task_id))
self.assert_deprecated(response)
- self.assertEqual(200, response.status_code)
- self.assertIn('state', response.data.decode('utf-8'))
- self.assertNotIn('error', response.data.decode('utf-8'))
+ assert 200 == response.status_code
+ assert 'state' in response.data.decode('utf-8')
+ assert 'error' not in response.data.decode('utf-8')
# Test error for nonexistent dag
response = self.client.get(
url_template.format('does_not_exist_dag', datetime_string, task_id),
)
- self.assertEqual(404, response.status_code)
- self.assertIn('error', response.data.decode('utf-8'))
+ assert 404 == response.status_code
+ assert 'error' in response.data.decode('utf-8')
# Test error for nonexistent task
response = self.client.get(url_template.format(dag_id, datetime_string, 'does_not_exist_task'))
- self.assertEqual(404, response.status_code)
- self.assertIn('error', response.data.decode('utf-8'))
+ assert 404 == response.status_code
+ assert 'error' in response.data.decode('utf-8')
# Test error for nonexistent dag run (wrong execution_date)
response = self.client.get(url_template.format(dag_id, wrong_datetime_string, task_id))
- self.assertEqual(404, response.status_code)
- self.assertIn('error', response.data.decode('utf-8'))
+ assert 404 == response.status_code
+ assert 'error' in response.data.decode('utf-8')
# Test error for bad datetime format
response = self.client.get(url_template.format(dag_id, 'not_a_datetime', task_id))
- self.assertEqual(400, response.status_code)
- self.assertIn('error', response.data.decode('utf-8'))
+ assert 400 == response.status_code
+ assert 'error' in response.data.decode('utf-8')
def test_dagrun_status(self):
url_template = '/api/experimental/dags/{}/dag_runs/{}'
@@ -278,26 +279,26 @@ def test_dagrun_status(self):
# Test Correct execution
response = self.client.get(url_template.format(dag_id, datetime_string))
self.assert_deprecated(response)
- self.assertEqual(200, response.status_code)
- self.assertIn('state', response.data.decode('utf-8'))
- self.assertNotIn('error', response.data.decode('utf-8'))
+ assert 200 == response.status_code
+ assert 'state' in response.data.decode('utf-8')
+ assert 'error' not in response.data.decode('utf-8')
# Test error for nonexistent dag
response = self.client.get(
url_template.format('does_not_exist_dag', datetime_string),
)
- self.assertEqual(404, response.status_code)
- self.assertIn('error', response.data.decode('utf-8'))
+ assert 404 == response.status_code
+ assert 'error' in response.data.decode('utf-8')
# Test error for nonexistent dag run (wrong execution_date)
response = self.client.get(url_template.format(dag_id, wrong_datetime_string))
- self.assertEqual(404, response.status_code)
- self.assertIn('error', response.data.decode('utf-8'))
+ assert 404 == response.status_code
+ assert 'error' in response.data.decode('utf-8')
# Test error for bad datetime format
response = self.client.get(url_template.format(dag_id, 'not_a_datetime'))
- self.assertEqual(400, response.status_code)
- self.assertIn('error', response.data.decode('utf-8'))
+ assert 400 == response.status_code
+ assert 'error' in response.data.decode('utf-8')
class TestLineageApiExperimental(TestBase):
@@ -331,26 +332,26 @@ def test_lineage_info(self):
# test correct execution
response = self.client.get(url_template.format(dag_id, datetime_string))
self.assert_deprecated(response)
- self.assertEqual(200, response.status_code)
- self.assertIn('task_ids', response.data.decode('utf-8'))
- self.assertNotIn('error', response.data.decode('utf-8'))
+ assert 200 == response.status_code
+ assert 'task_ids' in response.data.decode('utf-8')
+ assert 'error' not in response.data.decode('utf-8')
# Test error for nonexistent dag
response = self.client.get(
url_template.format('does_not_exist_dag', datetime_string),
)
- self.assertEqual(404, response.status_code)
- self.assertIn('error', response.data.decode('utf-8'))
+ assert 404 == response.status_code
+ assert 'error' in response.data.decode('utf-8')
# Test error for nonexistent dag run (wrong execution_date)
response = self.client.get(url_template.format(dag_id, wrong_datetime_string))
- self.assertEqual(404, response.status_code)
- self.assertIn('error', response.data.decode('utf-8'))
+ assert 404 == response.status_code
+ assert 'error' in response.data.decode('utf-8')
# Test error for bad datetime format
response = self.client.get(url_template.format(dag_id, 'not_a_datetime'))
- self.assertEqual(400, response.status_code)
- self.assertIn('error', response.data.decode('utf-8'))
+ assert 400 == response.status_code
+ assert 'error' in response.data.decode('utf-8')
class TestPoolApiExperimental(TestBase):
@@ -380,7 +381,7 @@ def setUp(self):
def _get_pool_count(self):
response = self.client.get('/api/experimental/pools')
- self.assertEqual(response.status_code, 200)
+ assert response.status_code == 200
return len(json.loads(response.data.decode('utf-8')))
def test_get_pool(self):
@@ -388,22 +389,22 @@ def test_get_pool(self):
f'/api/experimental/pools/{self.pool.pool}',
)
self.assert_deprecated(response)
- self.assertEqual(response.status_code, 200)
- self.assertEqual(json.loads(response.data.decode('utf-8')), self.pool.to_json())
+ assert response.status_code == 200
+ assert json.loads(response.data.decode('utf-8')) == self.pool.to_json()
def test_get_pool_non_existing(self):
response = self.client.get('/api/experimental/pools/foo')
- self.assertEqual(response.status_code, 404)
- self.assertEqual(json.loads(response.data.decode('utf-8'))['error'], "Pool 'foo' doesn't exist")
+ assert response.status_code == 404
+ assert json.loads(response.data.decode('utf-8'))['error'] == "Pool 'foo' doesn't exist"
def test_get_pools(self):
response = self.client.get('/api/experimental/pools')
self.assert_deprecated(response)
- self.assertEqual(response.status_code, 200)
+ assert response.status_code == 200
pools = json.loads(response.data.decode('utf-8'))
- self.assertEqual(len(pools), self.TOTAL_POOL_COUNT)
+ assert len(pools) == self.TOTAL_POOL_COUNT
for i, pool in enumerate(sorted(pools, key=lambda p: p['pool'])):
- self.assertDictEqual(pool, self.pools[i].to_json())
+ assert pool == self.pools[i].to_json()
def test_create_pool(self):
response = self.client.post(
@@ -418,12 +419,12 @@ def test_create_pool(self):
content_type='application/json',
)
self.assert_deprecated(response)
- self.assertEqual(response.status_code, 200)
+ assert response.status_code == 200
pool = json.loads(response.data.decode('utf-8'))
- self.assertEqual(pool['pool'], 'foo')
- self.assertEqual(pool['slots'], 1)
- self.assertEqual(pool['description'], '')
- self.assertEqual(self._get_pool_count(), self.TOTAL_POOL_COUNT + 1)
+ assert pool['pool'] == 'foo'
+ assert pool['slots'] == 1
+ assert pool['description'] == ''
+ assert self._get_pool_count() == self.TOTAL_POOL_COUNT + 1
def test_create_pool_with_bad_name(self):
for name in ('', ' '):
@@ -438,33 +439,30 @@ def test_create_pool_with_bad_name(self):
),
content_type='application/json',
)
- self.assertEqual(response.status_code, 400)
- self.assertEqual(
- json.loads(response.data.decode('utf-8'))['error'],
- "Pool name shouldn't be empty",
- )
- self.assertEqual(self._get_pool_count(), self.TOTAL_POOL_COUNT)
+ assert response.status_code == 400
+ assert json.loads(response.data.decode('utf-8'))['error'] == "Pool name shouldn't be empty"
+ assert self._get_pool_count() == self.TOTAL_POOL_COUNT
def test_delete_pool(self):
response = self.client.delete(
f'/api/experimental/pools/{self.pool.pool}',
)
self.assert_deprecated(response)
- self.assertEqual(response.status_code, 200)
- self.assertEqual(json.loads(response.data.decode('utf-8')), self.pool.to_json())
- self.assertEqual(self._get_pool_count(), self.TOTAL_POOL_COUNT - 1)
+ assert response.status_code == 200
+ assert json.loads(response.data.decode('utf-8')) == self.pool.to_json()
+ assert self._get_pool_count() == self.TOTAL_POOL_COUNT - 1
def test_delete_pool_non_existing(self):
response = self.client.delete(
'/api/experimental/pools/foo',
)
- self.assertEqual(response.status_code, 404)
- self.assertEqual(json.loads(response.data.decode('utf-8'))['error'], "Pool 'foo' doesn't exist")
+ assert response.status_code == 404
+ assert json.loads(response.data.decode('utf-8'))['error'] == "Pool 'foo' doesn't exist"
def test_delete_default_pool(self):
clear_db_pools()
response = self.client.delete(
'/api/experimental/pools/default_pool',
)
- self.assertEqual(response.status_code, 400)
- self.assertEqual(json.loads(response.data.decode('utf-8'))['error'], "default_pool cannot be deleted")
+ assert response.status_code == 400
+ assert json.loads(response.data.decode('utf-8'))['error'] == "default_pool cannot be deleted"
diff --git a/tests/www/test_app.py b/tests/www/test_app.py
index 5240d9436acb9..b731db57551a4 100644
--- a/tests/www/test_app.py
+++ b/tests/www/test_app.py
@@ -55,10 +55,10 @@ def debug_view():
from flask import request
# Should respect HTTP_X_FORWARDED_FOR
- self.assertEqual(request.remote_addr, '192.168.0.1')
+ assert request.remote_addr == '192.168.0.1'
# Should respect HTTP_X_FORWARDED_PROTO, HTTP_X_FORWARDED_HOST, HTTP_X_FORWARDED_PORT,
# HTTP_X_FORWARDED_PREFIX
- self.assertEqual(request.url, 'https://valid:445/proxy-prefix/debug')
+ assert request.url == 'https://valid:445/proxy-prefix/debug'
return Response("success")
@@ -78,8 +78,8 @@ def debug_view():
response = Response.from_app(app, environ)
- self.assertEqual(b"success", response.get_data())
- self.assertEqual(response.status_code, 200)
+ assert b"success" == response.get_data()
+ assert response.status_code == 200
@conf_vars(
{
@@ -95,10 +95,10 @@ def debug_view():
from flask import request
# Should ignore HTTP_X_FORWARDED_FOR
- self.assertEqual(request.remote_addr, '192.168.0.2')
+ assert request.remote_addr == '192.168.0.2'
# Should ignore HTTP_X_FORWARDED_PROTO, HTTP_X_FORWARDED_HOST, HTTP_X_FORWARDED_PORT,
# HTTP_X_FORWARDED_PREFIX
- self.assertEqual(request.url, 'http://invalid:9000/internal-client/debug')
+ assert request.url == 'http://invalid:9000/internal-client/debug'
return Response("success")
@@ -118,8 +118,8 @@ def debug_view():
response = Response.from_app(app, environ)
- self.assertEqual(b"success", response.get_data())
- self.assertEqual(response.status_code, 200)
+ assert b"success" == response.get_data()
+ assert response.status_code == 200
@conf_vars(
{
@@ -141,9 +141,9 @@ def debug_view():
from flask import request
# Should use original REMOTE_ADDR
- self.assertEqual(request.remote_addr, '192.168.0.1')
+ assert request.remote_addr == '192.168.0.1'
# Should respect base_url
- self.assertEqual(request.url, "http://invalid:9000/internal-client/debug")
+ assert request.url == "http://invalid:9000/internal-client/debug"
return Response("success")
@@ -158,8 +158,8 @@ def debug_view():
response = Response.from_app(app, environ)
- self.assertEqual(b"success", response.get_data())
- self.assertEqual(response.status_code, 200)
+ assert b"success" == response.get_data()
+ assert response.status_code == 200
@conf_vars(
{
@@ -181,10 +181,10 @@ def debug_view():
from flask import request
# Should respect HTTP_X_FORWARDED_FOR
- self.assertEqual(request.remote_addr, '192.168.0.1')
+ assert request.remote_addr == '192.168.0.1'
# Should respect HTTP_X_FORWARDED_PROTO, HTTP_X_FORWARDED_HOST, HTTP_X_FORWARDED_PORT,
# HTTP_X_FORWARDED_PREFIX and use base_url
- self.assertEqual(request.url, "https://valid:445/proxy-prefix/internal-client/debug")
+ assert request.url == "https://valid:445/proxy-prefix/internal-client/debug"
return Response("success")
@@ -204,8 +204,8 @@ def debug_view():
response = Response.from_app(app, environ)
- self.assertEqual(b"success", response.get_data())
- self.assertEqual(response.status_code, 200)
+ assert b"success" == response.get_data()
+ assert response.status_code == 200
@conf_vars(
{
@@ -221,7 +221,7 @@ def debug_view():
def test_should_set_sqlalchemy_engine_options(self):
app = application.cached_app(testing=True)
engine_params = {'pool_size': 3, 'pool_recycle': 120, 'pool_pre_ping': True, 'max_overflow': 5}
- self.assertEqual(app.config['SQLALCHEMY_ENGINE_OPTIONS'], engine_params)
+ assert app.config['SQLALCHEMY_ENGINE_OPTIONS'] == engine_params
@conf_vars(
{
@@ -231,11 +231,11 @@ def test_should_set_sqlalchemy_engine_options(self):
@mock.patch("airflow.www.app.app", None)
def test_should_set_permanent_session_timeout(self):
app = application.cached_app(testing=True)
- self.assertEqual(app.config['PERMANENT_SESSION_LIFETIME'], timedelta(minutes=3600))
+ assert app.config['PERMANENT_SESSION_LIFETIME'] == timedelta(minutes=3600)
class TestFlaskCli(unittest.TestCase):
def test_flask_cli_should_display_routes(self):
with mock.patch.dict("os.environ", FLASK_APP="airflow.www.app:create_app"):
output = subprocess.check_output(["flask", "routes"])
- self.assertIn("/api/v1/version", output.decode())
+ assert "/api/v1/version" in output.decode()
diff --git a/tests/www/test_init_views.py b/tests/www/test_init_views.py
index 12eab340ec670..e0624b6ae6178 100644
--- a/tests/www/test_init_views.py
+++ b/tests/www/test_init_views.py
@@ -19,6 +19,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.www.extensions import init_views
from tests.test_utils.config import conf_vars
@@ -27,13 +29,12 @@ class TestInitApiExperimental(unittest.TestCase):
@conf_vars({('api', 'enable_experimental_api'): 'true'})
def test_should_raise_deprecation_warning_when_enabled(self):
app = mock.MagicMock()
- with self.assertWarnsRegex(DeprecationWarning, re.escape("The experimental REST API is deprecated.")):
+ with pytest.warns(DeprecationWarning, match=re.escape("The experimental REST API is deprecated.")):
init_views.init_api_experimental(app)
@conf_vars({('api', 'enable_experimental_api'): 'false'})
def test_should_not_raise_deprecation_warning_when_disabled(self):
app = mock.MagicMock()
- with self.assertRaises(AssertionError), self.assertWarnsRegex(
- DeprecationWarning, re.escape("The experimental REST API is deprecated.")
- ):
+ with pytest.warns(None) as warnings:
init_views.init_api_experimental(app)
+ assert len(warnings) == 0
diff --git a/tests/www/test_security.py b/tests/www/test_security.py
index 7916653f83e8d..a8ec939a49a5d 100644
--- a/tests/www/test_security.py
+++ b/tests/www/test_security.py
@@ -20,6 +20,7 @@
import unittest
from unittest import mock
+import pytest
from flask_appbuilder import SQLA, Model, expose, has_access
from flask_appbuilder.security.sqla import models as sqla_models
from flask_appbuilder.views import BaseView, ModelView
@@ -111,17 +112,13 @@ def expect_user_is_in_role(self, user, rolename):
def assert_user_has_dag_perms(self, perms, dag_id, user=None):
for perm in perms:
- self.assertTrue(
- self._has_dag_perm(perm, dag_id, user),
- f"User should have '{perm}' on DAG '{dag_id}'",
- )
+ assert self._has_dag_perm(perm, dag_id, user), f"User should have '{perm}' on DAG '{dag_id}'"
def assert_user_does_not_have_dag_perms(self, dag_id, perms, user=None):
for perm in perms:
- self.assertFalse(
- self._has_dag_perm(perm, dag_id, user),
- f"User should not have '{perm}' on DAG '{dag_id}'",
- )
+ assert not self._has_dag_perm(
+ perm, dag_id, user
+ ), f"User should not have '{perm}' on DAG '{dag_id}'"
def _has_dag_perm(self, perm, dag_id, user):
# if not user:
@@ -141,8 +138,8 @@ def test_init_role_baseview(self):
role_perms = [('can_some_action', 'SomeBaseView')]
self.security_manager.init_role(role_name, perms=role_perms)
role = self.appbuilder.sm.find_role(role_name)
- self.assertIsNotNone(role)
- self.assertEqual(len(role_perms), len(role.permissions))
+ assert role is not None
+ assert len(role_perms) == len(role.permissions)
def test_init_role_modelview(self):
role_name = 'MyRole2'
@@ -155,8 +152,8 @@ def test_init_role_modelview(self):
]
self.security_manager.init_role(role_name, role_perms)
role = self.appbuilder.sm.find_role(role_name)
- self.assertIsNotNone(role)
- self.assertEqual(len(role_perms), len(role.permissions))
+ assert role is not None
+ assert len(role_perms) == len(role.permissions)
def test_update_and_verify_permission_role(self):
role_name = 'Test_Role'
@@ -170,14 +167,14 @@ def test_update_and_verify_permission_role(self):
self.security_manager.init_role(role_name, [])
new_role_perms_len = len(role.permissions)
- self.assertEqual(role_perms_len, new_role_perms_len)
+ assert role_perms_len == new_role_perms_len
def test_get_user_roles(self):
user = mock.MagicMock()
user.is_anonymous = False
roles = self.appbuilder.sm.find_role('Admin')
user.roles = roles
- self.assertEqual(self.security_manager.get_user_roles(user), roles)
+ assert self.security_manager.get_user_roles(user) == roles
def test_get_user_roles_for_anonymous_user(self):
viewer_role_perms = {
@@ -226,7 +223,7 @@ def test_get_user_roles_for_anonymous_user(self):
perms_views.update(
{(perm_view.permission.name, perm_view.view_menu.name) for perm_view in role.permissions}
)
- self.assertEqual(perms_views, viewer_role_perms)
+ assert perms_views == viewer_role_perms
@mock.patch('airflow.www.security.AirflowSecurityManager.get_user_roles')
def test_get_all_permissions_views(self, mock_get_user_roles):
@@ -247,10 +244,10 @@ def test_get_all_permissions_views(self, mock_get_user_roles):
role = user.roles[0]
mock_get_user_roles.return_value = [role]
- self.assertEqual(self.security_manager.get_all_permissions_views(), {(role_perm, role_vm)})
+ assert self.security_manager.get_all_permissions_views() == {(role_perm, role_vm)}
mock_get_user_roles.return_value = []
- self.assertEqual(len(self.security_manager.get_all_permissions_views()), 0)
+ assert len(self.security_manager.get_all_permissions_views()) == 0
def test_get_accessible_dag_ids(self):
role_name = 'MyRole1'
@@ -276,7 +273,7 @@ def test_get_accessible_dag_ids(self):
dag_id, access_control={role_name: permission_action}
)
- self.assertEqual(self.security_manager.get_accessible_dag_ids(user), {'dag_id'})
+ assert self.security_manager.get_accessible_dag_ids(user) == {'dag_id'}
def test_dont_get_inaccessible_dag_ids_for_dag_resource_permission(self):
# In this test case,
@@ -303,48 +300,50 @@ def test_dont_get_inaccessible_dag_ids_for_dag_resource_permission(self):
dag_id, access_control={role_name: permission_action}
)
- self.assertEqual(self.security_manager.get_readable_dag_ids(user), set())
+ assert self.security_manager.get_readable_dag_ids(user) == set()
@mock.patch('airflow.www.security.AirflowSecurityManager._has_view_access')
def test_has_access(self, mock_has_view_access):
user = mock.MagicMock()
user.is_anonymous = False
mock_has_view_access.return_value = True
- self.assertTrue(self.security_manager.has_access('perm', 'view', user))
+ assert self.security_manager.has_access('perm', 'view', user)
def test_sync_perm_for_dag_creates_permissions_on_view_menus(self):
test_dag_id = 'TEST_DAG'
prefixed_test_dag_id = f'DAG:{test_dag_id}'
self.security_manager.sync_perm_for_dag(test_dag_id, access_control=None)
- self.assertIsNotNone(
+ assert (
self.security_manager.find_permission_view_menu(permissions.ACTION_CAN_READ, prefixed_test_dag_id)
+ is not None
)
- self.assertIsNotNone(
+ assert (
self.security_manager.find_permission_view_menu(permissions.ACTION_CAN_EDIT, prefixed_test_dag_id)
+ is not None
)
@mock.patch('airflow.www.security.AirflowSecurityManager._has_perm')
@mock.patch('airflow.www.security.AirflowSecurityManager._has_role')
def test_has_all_dag_access(self, mock_has_role, mock_has_perm):
mock_has_role.return_value = True
- self.assertTrue(self.security_manager.has_all_dags_access())
+ assert self.security_manager.has_all_dags_access()
mock_has_role.return_value = False
mock_has_perm.return_value = False
- self.assertFalse(self.security_manager.has_all_dags_access())
+ assert not self.security_manager.has_all_dags_access()
mock_has_perm.return_value = True
- self.assertTrue(self.security_manager.has_all_dags_access())
+ assert self.security_manager.has_all_dags_access()
def test_access_control_with_non_existent_role(self):
- with self.assertRaises(AirflowException) as context:
+ with pytest.raises(AirflowException) as ctx:
self.security_manager.sync_perm_for_dag(
dag_id='access-control-test',
access_control={
'this-role-does-not-exist': [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ]
},
)
- self.assertIn("role does not exist", str(context.exception))
+ assert "role does not exist" in str(ctx.value)
def test_all_dag_access_doesnt_give_non_dag_access(self):
username = 'dag_access_user'
@@ -359,13 +358,11 @@ def test_all_dag_access_doesnt_give_non_dag_access(self):
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
],
)
- self.assertTrue(
- self.security_manager.has_access(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG, user)
+ assert self.security_manager.has_access(
+ permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG, user
)
- self.assertFalse(
- self.security_manager.has_access(
- permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE, user
- )
+ assert not self.security_manager.has_access(
+ permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE, user
)
def test_access_control_with_invalid_permission(self):
@@ -381,11 +378,11 @@ def test_access_control_with_invalid_permission(self):
)
for permission in invalid_permissions:
self.expect_user_is_in_role(user, rolename='team-a')
- with self.assertRaises(AirflowException) as context:
+ with pytest.raises(AirflowException) as ctx:
self.security_manager.sync_perm_for_dag(
'access_control_test', access_control={'team-a': {permission}}
)
- self.assertIn("invalid permissions", str(context.exception))
+ assert "invalid permissions" in str(ctx.value)
def test_access_control_is_set_on_init(self):
username = 'access_control_is_set_on_init'
@@ -448,9 +445,9 @@ def test_no_additional_dag_permission_views_created(self):
num_pv_before = self.db.session().query(ab_perm_view_role).count()
self.security_manager.sync_roles()
num_pv_after = self.db.session().query(ab_perm_view_role).count()
- self.assertEqual(num_pv_before, num_pv_after)
+ assert num_pv_before == num_pv_after
def test_override_role_vm(self):
test_security_manager = MockSecurityManager(appbuilder=self.appbuilder)
- self.assertEqual(len(test_security_manager.VIEWER_VMS), 1)
- self.assertEqual(test_security_manager.VIEWER_VMS, {'Airflow'})
+ assert len(test_security_manager.VIEWER_VMS) == 1
+ assert test_security_manager.VIEWER_VMS == {'Airflow'}
diff --git a/tests/www/test_utils.py b/tests/www/test_utils.py
index b5088f4c17648..f6e53e4dbfbe1 100644
--- a/tests/www/test_utils.py
+++ b/tests/www/test_utils.py
@@ -16,6 +16,7 @@
# specific language governing permissions and limitations
# under the License.
+import re
import unittest
from datetime import datetime
from urllib.parse import parse_qs
@@ -30,17 +31,17 @@
class TestUtils(unittest.TestCase):
def test_empty_variable_should_not_be_hidden(self):
- self.assertFalse(utils.should_hide_value_for_key(""))
- self.assertFalse(utils.should_hide_value_for_key(None))
+ assert not utils.should_hide_value_for_key("")
+ assert not utils.should_hide_value_for_key(None)
def test_normal_variable_should_not_be_hidden(self):
- self.assertFalse(utils.should_hide_value_for_key("key"))
+ assert not utils.should_hide_value_for_key("key")
def test_sensitive_variable_should_be_hidden(self):
- self.assertTrue(utils.should_hide_value_for_key("google_api_key"))
+ assert utils.should_hide_value_for_key("google_api_key")
def test_sensitive_variable_should_be_hidden_ic(self):
- self.assertTrue(utils.should_hide_value_for_key("GOOGLE_API_KEY"))
+ assert utils.should_hide_value_for_key("GOOGLE_API_KEY")
@parameterized.expand(
[
@@ -55,7 +56,7 @@ def test_sensitive_variable_fields_should_be_hidden(
self, sensitive_variable_fields, key, expected_result
):
with conf_vars({('admin', 'sensitive_variable_fields'): str(sensitive_variable_fields)}):
- self.assertEqual(expected_result, utils.should_hide_value_for_key(key))
+ assert expected_result == utils.should_hide_value_for_key(key)
@parameterized.expand(
[
@@ -68,24 +69,24 @@ def test_normal_variable_fields_should_not_be_hidden(
self, sensitive_variable_fields, key, expected_result
):
with conf_vars({('admin', 'sensitive_variable_fields'): str(sensitive_variable_fields)}):
- self.assertEqual(expected_result, utils.should_hide_value_for_key(key))
+ assert expected_result == utils.should_hide_value_for_key(key)
def check_generate_pages_html(self, current_page, total_pages, window=7, check_middle=False):
extra_links = 4 # first, prev, next, last
search = "'>\"/>
"
html_str = utils.generate_pages(current_page, total_pages, search=search)
- self.assertNotIn(search, html_str, "The raw search string shouldn't appear in the output")
- self.assertIn('search=%27%3E%22%2F%3E%3Cimg+src%3Dx+onerror%3Dalert%281%29%3E', html_str)
+ assert search not in html_str, "The raw search string shouldn't appear in the output"
+ assert 'search=%27%3E%22%2F%3E%3Cimg+src%3Dx+onerror%3Dalert%281%29%3E' in html_str
- self.assertTrue(callable(html_str.__html__), "Should return something that is HTML-escaping aware")
+ assert callable(html_str.__html__), "Should return something that is HTML-escaping aware"
dom = BeautifulSoup(html_str, 'html.parser')
- self.assertIsNotNone(dom)
+ assert dom is not None
ulist = dom.ul
ulist_items = ulist.find_all('li')
- self.assertEqual(min(window, total_pages) + extra_links, len(ulist_items))
+ assert min(window, total_pages) + extra_links == len(ulist_items)
page_items = ulist_items[2:-2]
mid = int(len(page_items) / 2)
@@ -95,14 +96,14 @@ def check_generate_pages_html(self, current_page, total_pages, window=7, check_m
node_text = a_node.string
if node_text == str(current_page + 1):
if check_middle:
- self.assertEqual(mid, i)
- self.assertEqual('javascript:void(0)', href_link)
- self.assertIn('active', item['class'])
+ assert mid == i
+ assert 'javascript:void(0)' == href_link
+ assert 'active' in item['class']
else:
- self.assertRegex(href_link, r'^\?', 'Link is page-relative')
+ assert re.search(r'^\?', href_link), 'Link is page-relative'
query = parse_qs(href_link[1:])
- self.assertListEqual(query['page'], [str(int(node_text) - 1)])
- self.assertListEqual(query['search'], [search])
+ assert query['page'] == [str(int(node_text) - 1)]
+ assert query['search'] == [search]
def test_generate_pager_current_start(self):
self.check_generate_pages_html(current_page=0, total_pages=6)
@@ -115,25 +116,24 @@ def test_generate_pager_current_end(self):
def test_params_no_values(self):
"""Should return an empty string if no params are passed"""
- self.assertEqual('', utils.get_params())
+ assert '' == utils.get_params()
def test_params_search(self):
- self.assertEqual('search=bash_', utils.get_params(search='bash_'))
+ assert 'search=bash_' == utils.get_params(search='bash_')
def test_params_none_and_zero(self):
query_str = utils.get_params(a=0, b=None, c='true')
# The order won't be consistent, but that doesn't affect behaviour of a browser
pairs = list(sorted(query_str.split('&')))
- self.assertListEqual(['a=0', 'c=true'], pairs)
+ assert ['a=0', 'c=true'] == pairs
def test_params_all(self):
query = utils.get_params(status='active', page=3, search='bash_')
- self.assertEqual({'page': ['3'], 'search': ['bash_'], 'status': ['active']}, parse_qs(query))
+ assert {'page': ['3'], 'search': ['bash_'], 'status': ['active']} == parse_qs(query)
def test_params_escape(self):
- self.assertEqual(
- 'search=%27%3E%22%2F%3E%3Cimg+src%3Dx+onerror%3Dalert%281%29%3E',
- utils.get_params(search="'>\"/>
"),
+ assert 'search=%27%3E%22%2F%3E%3Cimg+src%3Dx+onerror%3Dalert%281%29%3E' == utils.get_params(
+ search="'>\"/>
"
)
def test_state_token(self):
@@ -141,14 +141,8 @@ def test_state_token(self):
# ensure they are escaped!
html = str(utils.state_token(''))
- self.assertIn(
- '<script>alert(1)</script>',
- html,
- )
- self.assertNotIn(
- '',
- html,
- )
+ assert '<script>alert(1)</script>' in html
+ assert '' not in html
def test_task_instance_link(self):
@@ -161,10 +155,10 @@ def test_task_instance_link(self):
)
)
- self.assertIn('%3Ca%261%3E', html)
- self.assertIn('%3Cb2%3E', html)
- self.assertNotIn('', html)
- self.assertNotIn('', html)
+ assert '%3Ca%261%3E' in html
+ assert '%3Cb2%3E' in html
+ assert '' not in html
+ assert '' not in html
def test_dag_link(self):
from airflow.www.app import cached_app
@@ -172,8 +166,8 @@ def test_dag_link(self):
with cached_app(testing=True).test_request_context():
html = str(utils.dag_link({'dag_id': '', 'execution_date': datetime.now()}))
- self.assertIn('%3Ca%261%3E', html)
- self.assertNotIn('', html)
+ assert '%3Ca%261%3E' in html
+ assert '' not in html
def test_dag_link_when_dag_is_none(self):
"""Test that when there is no dag_id, dag_link does not contain hyperlink"""
@@ -182,8 +176,8 @@ def test_dag_link_when_dag_is_none(self):
with cached_app(testing=True).test_request_context():
html = str(utils.dag_link({}))
- self.assertIn('None', html)
- self.assertNotIn('', 'run_id': '', 'execution_date': datetime.now()})
)
- self.assertIn('%3Ca%261%3E', html)
- self.assertIn('%3Cb2%3E', html)
- self.assertNotIn('', html)
- self.assertNotIn('', html)
+ assert '%3Ca%261%3E' in html
+ assert '%3Cb2%3E' in html
+ assert '' not in html
+ assert '' not in html
class TestAttrRenderer(unittest.TestCase):
@@ -208,34 +202,34 @@ def example_callable(unused_self):
print("example")
rendered = self.attr_renderer["python_callable"](example_callable)
- self.assertIn('"example"', rendered)
+ assert '"example"' in rendered
def test_python_callable_none(self):
rendered = self.attr_renderer["python_callable"](None)
- self.assertEqual("", rendered)
+ assert "" == rendered
def test_markdown(self):
markdown = "* foo\n* bar"
rendered = self.attr_renderer["doc_md"](markdown)
- self.assertIn("foo", rendered)
- self.assertIn("bar", rendered)
+ assert "foo" in rendered
+ assert "bar" in rendered
def test_markdown_none(self):
rendered = self.attr_renderer["python_callable"](None)
- self.assertEqual("", rendered)
+ assert "" == rendered
class TestWrappedMarkdown(unittest.TestCase):
def test_wrapped_markdown_with_docstring_curly_braces(self):
rendered = wrapped_markdown("{braces}", css_class="a_class")
- self.assertEqual('', rendered)
+ assert '' == rendered
def test_wrapped_markdown_with_some_markdown(self):
rendered = wrapped_markdown("*italic*\n**bold**\n", css_class="a_class")
- self.assertEqual(
+ assert (
'''''',
- rendered,
+bold
'''
+ == rendered
)
def test_wrapped_markdown_with_table(self):
@@ -245,12 +239,9 @@ def test_wrapped_markdown_with_table(self):
| ETL | 14m |"""
)
- self.assertEqual(
- (
- '\n\n\n| Job | \n'
- 'Duration | \n
\n\n\n\n| ETL'
- ' | \n14m | \n
\n\n'
- '
'
- ),
- rendered,
- )
+ assert (
+ '\n\n\n| Job | \n'
+ 'Duration | \n
\n\n\n\n| ETL'
+ ' | \n14m | \n
\n\n'
+ '
'
+ ) == rendered
diff --git a/tests/www/test_validators.py b/tests/www/test_validators.py
index df36b61f1d294..aff1524ee2715 100644
--- a/tests/www/test_validators.py
+++ b/tests/www/test_validators.py
@@ -19,6 +19,8 @@
import unittest
from unittest import mock
+import pytest
+
from airflow.www import validators
@@ -43,50 +45,47 @@ def _validate(self, fieldname=None, message=None):
return validator(self.form_mock, self.form_field_mock)
def test_field_not_found(self):
- self.assertRaisesRegex(
- validators.ValidationError,
- "^Invalid field name 'some'.$",
- self._validate,
- fieldname='some',
- )
+ with pytest.raises(validators.ValidationError, match="^Invalid field name 'some'.$"):
+ self._validate(
+ fieldname='some',
+ )
def test_form_field_is_none(self):
self.form_field_mock.data = None
- self.assertIsNone(self._validate())
+ assert self._validate() is None
def test_other_field_is_none(self):
self.other_field_mock.data = None
- self.assertIsNone(self._validate())
+ assert self._validate() is None
def test_both_fields_are_none(self):
self.form_field_mock.data = None
self.other_field_mock.data = None
- self.assertIsNone(self._validate())
+ assert self._validate() is None
def test_validation_pass(self):
- self.assertIsNone(self._validate())
+ assert self._validate() is None
def test_validation_raises(self):
self.form_field_mock.data = '2017-05-04'
- self.assertRaisesRegex(
- validators.ValidationError,
- "^Field must be greater than or equal to other field.$",
- self._validate,
- )
+ with pytest.raises(
+ validators.ValidationError, match="^Field must be greater than or equal to other field.$"
+ ):
+ self._validate()
def test_validation_raises_custom_message(self):
self.form_field_mock.data = '2017-05-04'
- self.assertRaisesRegex(
- validators.ValidationError,
- "^This field must be greater than or equal to MyField.$",
- self._validate,
- message="This field must be greater than or equal to MyField.",
- )
+ with pytest.raises(
+ validators.ValidationError, match="^This field must be greater than or equal to MyField.$"
+ ):
+ self._validate(
+ message="This field must be greater than or equal to MyField.",
+ )
class TestValidJson(unittest.TestCase):
@@ -105,26 +104,21 @@ def _validate(self, message=None):
def test_form_field_is_none(self):
self.form_field_mock.data = None
- self.assertIsNone(self._validate())
+ assert self._validate() is None
def test_validation_pass(self):
- self.assertIsNone(self._validate())
+ assert self._validate() is None
def test_validation_raises_default_message(self):
self.form_field_mock.data = '2017-05-04'
- self.assertRaisesRegex(
- validators.ValidationError,
- "JSON Validation Error:.*",
- self._validate,
- )
+ with pytest.raises(validators.ValidationError, match="JSON Validation Error:.*"):
+ self._validate()
def test_validation_raises_custom_message(self):
self.form_field_mock.data = '2017-05-04'
- self.assertRaisesRegex(
- validators.ValidationError,
- "Invalid JSON",
- self._validate,
- message="Invalid JSON: {}",
- )
+ with pytest.raises(validators.ValidationError, match="Invalid JSON"):
+ self._validate(
+ message="Invalid JSON: {}",
+ )
diff --git a/tests/www/test_views.py b/tests/www/test_views.py
index d05077f2f0e96..72615085e6408 100644
--- a/tests/www/test_views.py
+++ b/tests/www/test_views.py
@@ -194,21 +194,21 @@ def clear_table(cls, model):
def check_content_in_response(self, text, resp, resp_code=200):
resp_html = resp.data.decode('utf-8')
- self.assertEqual(resp_code, resp.status_code)
+ assert resp_code == resp.status_code
if isinstance(text, list):
for line in text:
- self.assertIn(line, resp_html)
+ assert line in resp_html
else:
- self.assertIn(text, resp_html)
+ assert text in resp_html
def check_content_not_in_response(self, text, resp, resp_code=200):
resp_html = resp.data.decode('utf-8')
- self.assertEqual(resp_code, resp.status_code)
+ assert resp_code == resp.status_code
if isinstance(text, list):
for line in text:
- self.assertNotIn(line, resp_html)
+ assert line not in resp_html
else:
- self.assertNotIn(text, resp_html)
+ assert text not in resp_html
@staticmethod
def percent_encode(obj):
@@ -289,8 +289,8 @@ def test_xss_prevention(self):
xss,
follow_redirects=True,
)
- self.assertEqual(resp.status_code, 404)
- self.assertNotIn("
", resp.data.decode("utf-8"))
+ assert resp.status_code == 404
+ assert "
" not in resp.data.decode("utf-8")
def test_import_variables_no_file(self):
resp = self.client.post('/variable/varimport', follow_redirects=True)
@@ -301,7 +301,7 @@ def test_import_variables_failed(self):
with mock.patch('airflow.models.Variable.set') as set_mock:
set_mock.side_effect = UnicodeEncodeError
- self.assertEqual(self.session.query(models.Variable).count(), 0)
+ assert self.session.query(models.Variable).count() == 0
try:
# python 3+
@@ -316,7 +316,7 @@ def test_import_variables_failed(self):
self.check_content_in_response('1 variable(s) failed to be updated.', resp)
def test_import_variables_success(self):
- self.assertEqual(self.session.query(models.Variable).count(), 0)
+ assert self.session.query(models.Variable).count() == 0
content = (
'{"str_key": "str_value", "int_key": 60, "list_key": [1, 2], "dict_key": {"k_a": 2, "k_b": 3}}'
@@ -431,17 +431,17 @@ def tearDownClass(cls):
def test_mount(self):
# Test an endpoint that doesn't need auth!
resp = self.client.get('/test/health')
- self.assertEqual(resp.status_code, 200)
- self.assertIn(b"healthy", resp.data)
+ assert resp.status_code == 200
+ assert b"healthy" in resp.data
def test_not_found(self):
resp = self.client.get('/', follow_redirects=True)
- self.assertEqual(resp.status_code, 404)
+ assert resp.status_code == 404
def test_index(self):
resp = self.client.get('/test/')
- self.assertEqual(resp.status_code, 302)
- self.assertEqual(resp.headers['Location'], 'http://localhost/test/home')
+ assert resp.status_code == 302
+ assert resp.headers['Location'] == 'http://localhost/test/home'
class TestAirflowBaseViews(TestBase):
@@ -519,11 +519,11 @@ def test_health(self):
resp_json = json.loads(self.client.get('health', follow_redirects=True).data.decode('utf-8'))
- self.assertEqual('healthy', resp_json['metadatabase']['status'])
- self.assertEqual('healthy', resp_json['scheduler']['status'])
- self.assertEqual(
- last_scheduler_heartbeat_for_testing_1.isoformat(),
- resp_json['scheduler']['latest_scheduler_heartbeat'],
+ assert 'healthy' == resp_json['metadatabase']['status']
+ assert 'healthy' == resp_json['scheduler']['status']
+ assert (
+ last_scheduler_heartbeat_for_testing_1.isoformat()
+ == resp_json['scheduler']['latest_scheduler_heartbeat']
)
self.session.query(BaseJob).filter(
@@ -551,11 +551,11 @@ def test_health(self):
resp_json = json.loads(self.client.get('health', follow_redirects=True).data.decode('utf-8'))
- self.assertEqual('healthy', resp_json['metadatabase']['status'])
- self.assertEqual('unhealthy', resp_json['scheduler']['status'])
- self.assertEqual(
- last_scheduler_heartbeat_for_testing_2.isoformat(),
- resp_json['scheduler']['latest_scheduler_heartbeat'],
+ assert 'healthy' == resp_json['metadatabase']['status']
+ assert 'unhealthy' == resp_json['scheduler']['status']
+ assert (
+ last_scheduler_heartbeat_for_testing_2.isoformat()
+ == resp_json['scheduler']['latest_scheduler_heartbeat']
)
self.session.query(BaseJob).filter(
@@ -573,9 +573,9 @@ def test_health(self):
resp_json = json.loads(self.client.get('health', follow_redirects=True).data.decode('utf-8'))
- self.assertEqual('healthy', resp_json['metadatabase']['status'])
- self.assertEqual('unhealthy', resp_json['scheduler']['status'])
- self.assertIsNone(None, resp_json['scheduler']['latest_scheduler_heartbeat'])
+ assert 'healthy' == resp_json['metadatabase']['status']
+ assert 'unhealthy' == resp_json['scheduler']['status']
+ assert resp_json['scheduler']['latest_scheduler_heartbeat'] is None
def test_home(self):
with self.capture_templates() as templates:
@@ -592,11 +592,11 @@ def test_home(self):
)
self.check_content_in_response(val_state_color_mapping, resp)
- self.assertEqual(len(templates), 1)
- self.assertEqual(templates[0].name, 'airflow/dags.html')
+ assert len(templates) == 1
+ assert templates[0].name == 'airflow/dags.html'
state_color_mapping = State.state_color.copy()
state_color_mapping["null"] = state_color_mapping.pop(None)
- self.assertEqual(templates[0].local_context['state_color'], state_color_mapping)
+ assert templates[0].local_context['state_color'] == state_color_mapping
def test_users_list(self):
resp = self.client.get('users/list', follow_redirects=True)
@@ -627,26 +627,26 @@ def test_home_filter_tags(self):
with self.client:
self.client.get('home?tags=example&tags=data', follow_redirects=True)
- self.assertEqual('example,data', flask_session[FILTER_TAGS_COOKIE])
+ assert 'example,data' == flask_session[FILTER_TAGS_COOKIE]
self.client.get('home?reset_tags', follow_redirects=True)
- self.assertIsNone(flask_session[FILTER_TAGS_COOKIE])
+ assert flask_session[FILTER_TAGS_COOKIE] is None
def test_home_status_filter_cookie(self):
from airflow.www.views import FILTER_STATUS_COOKIE
with self.client:
self.client.get('home', follow_redirects=True)
- self.assertEqual('all', flask_session[FILTER_STATUS_COOKIE])
+ assert 'all' == flask_session[FILTER_STATUS_COOKIE]
self.client.get('home?status=active', follow_redirects=True)
- self.assertEqual('active', flask_session[FILTER_STATUS_COOKIE])
+ assert 'active' == flask_session[FILTER_STATUS_COOKIE]
self.client.get('home?status=paused', follow_redirects=True)
- self.assertEqual('paused', flask_session[FILTER_STATUS_COOKIE])
+ assert 'paused' == flask_session[FILTER_STATUS_COOKIE]
self.client.get('home?status=all', follow_redirects=True)
- self.assertEqual('all', flask_session[FILTER_STATUS_COOKIE])
+ assert 'all' == flask_session[FILTER_STATUS_COOKIE]
def test_task(self):
url = 'task?task_id=runme_0&dag_id=example_bash_operator&execution_date={}'.format(
@@ -687,26 +687,26 @@ def test_rendered_k8s_without_k8s(self):
self.percent_encode(self.EXAMPLE_DAG_DEFAULT_DATE)
)
resp = self.client.get(url, follow_redirects=True)
- self.assertEqual(404, resp.status_code)
+ assert 404 == resp.status_code
def test_blocked(self):
url = 'blocked'
resp = self.client.post(url, follow_redirects=True)
- self.assertEqual(200, resp.status_code)
+ assert 200 == resp.status_code
def test_dag_stats(self):
resp = self.client.post('dag_stats', follow_redirects=True)
- self.assertEqual(resp.status_code, 200)
+ assert resp.status_code == 200
def test_task_stats(self):
resp = self.client.post('task_stats', follow_redirects=True)
- self.assertEqual(resp.status_code, 200)
- self.assertEqual(set(list(resp.json.items())[0][1][0].keys()), {'state', 'count'})
+ assert resp.status_code == 200
+ assert set(list(resp.json.items())[0][1][0].keys()) == {'state', 'count'}
@conf_vars({("webserver", "show_recent_stats_for_completed_runs"): "False"})
def test_task_stats_only_noncompleted(self):
resp = self.client.post('task_stats', follow_redirects=True)
- self.assertEqual(resp.status_code, 200)
+ assert resp.status_code == 200
def test_dag_details(self):
url = 'dag_details?dag_id=example_bash_operator'
@@ -793,10 +793,10 @@ def test_last_dagruns_success_when_selecting_dags(self):
resp = self.client.post(
'last_dagruns', data={'dag_ids': ['example_subdag_operator']}, follow_redirects=True
)
- self.assertEqual(resp.status_code, 200)
+ assert resp.status_code == 200
stats = json.loads(resp.data.decode('utf-8'))
- self.assertNotIn('example_bash_operator', stats)
- self.assertIn('example_subdag_operator', stats)
+ assert 'example_bash_operator' not in stats
+ assert 'example_subdag_operator' in stats
# Multiple
resp = self.client.post(
@@ -805,8 +805,8 @@ def test_last_dagruns_success_when_selecting_dags(self):
follow_redirects=True,
)
stats = json.loads(resp.data.decode('utf-8'))
- self.assertIn('example_bash_operator', stats)
- self.assertIn('example_subdag_operator', stats)
+ assert 'example_bash_operator' in stats
+ assert 'example_subdag_operator' in stats
self.check_content_not_in_response('example_xcom', resp)
def test_tree(self):
@@ -1034,7 +1034,7 @@ def test_run_with_runnable_states(self, get_default_executor_function):
f"Task is in the '{state}' state which is not a valid state for execution. "
+ "The task must be cleared in order to be run"
)
- self.assertFalse(re.search(msg, resp.get_data(as_text=True)))
+ assert not re.search(msg, resp.get_data(as_text=True))
@mock.patch('airflow.executors.executor_loader.ExecutorLoader.get_default_executor')
def test_run_with_not_runnable_states(self, get_default_executor_function):
@@ -1043,7 +1043,7 @@ def test_run_with_not_runnable_states(self, get_default_executor_function):
task_id = 'runme_0'
for state in QUEUEABLE_STATES:
- self.assertFalse(state in RUNNABLE_STATES)
+ assert state not in RUNNABLE_STATES
self.session.query(models.TaskInstance).filter(models.TaskInstance.task_id == task_id).update(
{'state': state, 'end_date': timezone.utcnow()}
@@ -1066,7 +1066,7 @@ def test_run_with_not_runnable_states(self, get_default_executor_function):
f"Task is in the '{state}' state which is not a valid state for execution. "
+ "The task must be cleared in order to be run"
)
- self.assertTrue(re.search(msg, resp.get_data(as_text=True)))
+ assert re.search(msg, resp.get_data(as_text=True))
def test_refresh(self):
resp = self.client.post('refresh?dag_id=example_bash_operator')
@@ -1113,8 +1113,8 @@ def test_show_external_log_redirect_link_with_local_log_handler(self, endpoint):
with self.capture_templates() as templates:
self.client.get(url, follow_redirects=True)
ctx = templates[0].local_context
- self.assertFalse(ctx['show_external_log_redirect'])
- self.assertIsNone(ctx['external_log_name'])
+ assert not ctx['show_external_log_redirect']
+ assert ctx['external_log_name'] is None
@parameterized.expand(["graph", "tree"])
@mock.patch('airflow.utils.log.log_reader.TaskLogReader.log_handler', new_callable=PropertyMock)
@@ -1134,8 +1134,8 @@ def log_name(self):
with self.capture_templates() as templates:
self.client.get(url, follow_redirects=True)
ctx = templates[0].local_context
- self.assertTrue(ctx['show_external_log_redirect'])
- self.assertEqual(ctx['external_log_name'], ExternalHandler.LOG_NAME)
+ assert ctx['show_external_log_redirect']
+ assert ctx['external_log_name'] == ExternalHandler.LOG_NAME
class TestConfigurationView(TestBase):
@@ -1167,12 +1167,9 @@ def test_should_render_template(self):
resp = self.client.get('redoc')
self.check_content_in_response('Redoc', resp)
- self.assertEqual(len(templates), 1)
- self.assertEqual(templates[0].name, 'airflow/redoc.html')
- self.assertEqual(
- templates[0].local_context,
- {'openapi_spec_url': '/api/v1/openapi.yaml'},
- )
+ assert len(templates) == 1
+ assert templates[0].name == 'airflow/redoc.html'
+ assert templates[0].local_context == {'openapi_spec_url': '/api/v1/openapi.yaml'}
class TestLogView(TestBase):
@@ -1275,12 +1272,12 @@ def test_get_file_task_log(self, state, try_number, expected_num_logs_visible):
response = self.client.get(
TestLogView.ENDPOINT, data=dict(username='test', password='test'), follow_redirects=True
)
- self.assertEqual(response.status_code, 200)
- self.assertIn('Log by attempts', response.data.decode('utf-8'))
+ assert response.status_code == 200
+ assert 'Log by attempts' in response.data.decode('utf-8')
for num in range(1, expected_num_logs_visible + 1):
- self.assertIn(f'log-group-{num}', response.data.decode('utf-8'))
- self.assertNotIn('log-group-0', response.data.decode('utf-8'))
- self.assertNotIn('log-group-{}'.format(expected_num_logs_visible + 1), response.data.decode('utf-8'))
+ assert f'log-group-{num}' in response.data.decode('utf-8')
+ assert 'log-group-0' not in response.data.decode('utf-8')
+ assert 'log-group-{}'.format(expected_num_logs_visible + 1) not in response.data.decode('utf-8')
def test_get_logs_with_metadata_as_download_file(self):
url_template = (
@@ -1298,10 +1295,10 @@ def test_get_logs_with_metadata_as_download_file(self):
)
content_disposition = response.headers.get('Content-Disposition')
- self.assertTrue(content_disposition.startswith('attachment'))
- self.assertTrue(expected_filename in content_disposition)
- self.assertEqual(200, response.status_code)
- self.assertIn('Log for testing.', response.data.decode('utf-8'))
+ assert content_disposition.startswith('attachment')
+ assert expected_filename in content_disposition
+ assert 200 == response.status_code
+ assert 'Log for testing.' in response.data.decode('utf-8')
def test_get_logs_with_metadata_as_download_large_file(self):
with mock.patch("airflow.utils.log.file_task_handler.FileTaskHandler.read") as read_mock:
@@ -1325,10 +1322,10 @@ def test_get_logs_with_metadata_as_download_large_file(self):
)
response = self.client.get(url)
- self.assertIn('1st line', response.data.decode('utf-8'))
- self.assertIn('2nd line', response.data.decode('utf-8'))
- self.assertIn('3rd line', response.data.decode('utf-8'))
- self.assertNotIn('should never be read', response.data.decode('utf-8'))
+ assert '1st line' in response.data.decode('utf-8')
+ assert '2nd line' in response.data.decode('utf-8')
+ assert '3rd line' in response.data.decode('utf-8')
+ assert 'should never be read' not in response.data.decode('utf-8')
def test_get_logs_with_metadata(self):
url_template = (
@@ -1342,10 +1339,10 @@ def test_get_logs_with_metadata(self):
follow_redirects=True,
)
- self.assertIn('"message":', response.data.decode('utf-8'))
- self.assertIn('"metadata":', response.data.decode('utf-8'))
- self.assertIn('Log for testing.', response.data.decode('utf-8'))
- self.assertEqual(200, response.status_code)
+ assert '"message":' in response.data.decode('utf-8')
+ assert '"metadata":' in response.data.decode('utf-8')
+ assert 'Log for testing.' in response.data.decode('utf-8')
+ assert 200 == response.status_code
def test_get_logs_with_null_metadata(self):
url_template = (
@@ -1357,10 +1354,10 @@ def test_get_logs_with_null_metadata(self):
follow_redirects=True,
)
- self.assertIn('"message":', response.data.decode('utf-8'))
- self.assertIn('"metadata":', response.data.decode('utf-8'))
- self.assertIn('Log for testing.', response.data.decode('utf-8'))
- self.assertEqual(200, response.status_code)
+ assert '"message":' in response.data.decode('utf-8')
+ assert '"metadata":' in response.data.decode('utf-8')
+ assert 'Log for testing.' in response.data.decode('utf-8')
+ assert 200 == response.status_code
@mock.patch("airflow.utils.log.file_task_handler.FileTaskHandler.read")
def test_get_logs_with_metadata_for_removed_dag(self, mock_read):
@@ -1373,10 +1370,10 @@ def test_get_logs_with_metadata_for_removed_dag(self, mock_read):
)
response = self.client.get(url, data=dict(username='test', password='test'), follow_redirects=True)
- self.assertIn('"message":', response.data.decode('utf-8'))
- self.assertIn('"metadata":', response.data.decode('utf-8'))
- self.assertIn('airflow log line', response.data.decode('utf-8'))
- self.assertEqual(200, response.status_code)
+ assert '"message":' in response.data.decode('utf-8')
+ assert '"metadata":' in response.data.decode('utf-8')
+ assert 'airflow log line' in response.data.decode('utf-8')
+ assert 200 == response.status_code
def test_get_logs_response_with_ti_equal_to_none(self):
url_template = (
@@ -1393,9 +1390,9 @@ def test_get_logs_response_with_ti_equal_to_none(self):
json.dumps({}),
)
response = self.client.get(url)
- self.assertIn('message', response.json)
- self.assertIn('error', response.json)
- self.assertEqual("*** Task instance did not exist in the DB\n", response.json['message'])
+ assert 'message' in response.json
+ assert 'error' in response.json
+ assert "*** Task instance did not exist in the DB\n" == response.json['message']
def test_get_logs_with_json_response_format(self):
url_template = (
@@ -1408,10 +1405,10 @@ def test_get_logs_with_json_response_format(self):
self.DAG_ID, self.TASK_ID, quote_plus(self.DEFAULT_DATE.isoformat()), try_number, json.dumps({})
)
response = self.client.get(url)
- self.assertIn('message', response.json)
- self.assertIn('metadata', response.json)
- self.assertIn('Log for testing.', response.json['message'][0][1])
- self.assertEqual(200, response.status_code)
+ assert 'message' in response.json
+ assert 'metadata' in response.json
+ assert 'Log for testing.' in response.json['message'][0][1]
+ assert 200 == response.status_code
@mock.patch("airflow.www.views.TaskLogReader")
def test_get_logs_for_handler_without_read_method(self, mock_log_reader):
@@ -1427,10 +1424,10 @@ def test_get_logs_for_handler_without_read_method(self, mock_log_reader):
self.DAG_ID, self.TASK_ID, quote_plus(self.DEFAULT_DATE.isoformat()), try_number, json.dumps({})
)
response = self.client.get(url)
- self.assertEqual(200, response.status_code)
- self.assertIn('message', response.json)
- self.assertIn('metadata', response.json)
- self.assertIn('Task log handler does not support read logs.', response.json['message'])
+ assert 200 == response.status_code
+ assert 'message' in response.json
+ assert 'metadata' in response.json
+ assert 'Task log handler does not support read logs.' in response.json['message']
@parameterized.expand(
[
@@ -1445,8 +1442,8 @@ def test_redirect_to_external_log_with_local_log_handler(self, task_id):
url = url_template.format(self.DAG_ID, task_id, quote_plus(self.DEFAULT_DATE.isoformat()), try_number)
response = self.client.get(url)
- self.assertEqual(302, response.status_code)
- self.assertEqual('http://localhost/home', response.headers['Location'])
+ assert 302 == response.status_code
+ assert 'http://localhost/home' == response.headers['Location']
@mock.patch('airflow.utils.log.log_reader.TaskLogReader.log_handler', new_callable=PropertyMock)
def test_redirect_to_external_log_with_external_log_handler(self, mock_log_handler):
@@ -1465,8 +1462,8 @@ def get_external_log_url(self, *args, **kwargs):
)
response = self.client.get(url)
- self.assertEqual(302, response.status_code)
- self.assertEqual(ExternalHandler.EXTERNAL_URL, response.headers['Location'])
+ assert 302 == response.status_code
+ assert ExternalHandler.EXTERNAL_URL == response.headers['Location']
class ViewWithDateTimeAndNumRunsAndDagRunsFormTester:
@@ -1853,7 +1850,7 @@ def test_permission_exist(self):
test_view_menu = self.appbuilder.sm.find_view_menu('DAG:example_bash_operator')
perms_views = self.appbuilder.sm.find_permissions_view_menu(test_view_menu)
- self.assertEqual(len(perms_views), 2)
+ assert len(perms_views) == 2
perms = [str(perm) for perm in perms_views]
expected_perms = [
@@ -1861,7 +1858,7 @@ def test_permission_exist(self):
'can edit on DAG:example_bash_operator',
]
for perm in expected_perms:
- self.assertIn(perm, perms)
+ assert perm in perms
def test_role_permission_associate(self):
self.create_user_and_login(
@@ -1875,8 +1872,8 @@ def test_role_permission_associate(self):
test_role = self.appbuilder.sm.find_role('role_permission_associate_role')
perms = {str(perm) for perm in test_role.permissions}
- self.assertIn('can edit on DAG:example_bash_operator', perms)
- self.assertIn('can read on DAG:example_bash_operator', perms)
+ assert 'can edit on DAG:example_bash_operator' in perms
+ assert 'can read on DAG:example_bash_operator' in perms
def test_index_success(self):
self.create_user_and_login(
@@ -1932,7 +1929,7 @@ def test_dag_stats_success(self):
resp = self.client.post('dag_stats', follow_redirects=True)
self.check_content_in_response('example_bash_operator', resp)
- self.assertEqual(set(list(resp.json.items())[0][1][0].keys()), {'state', 'count'})
+ assert set(list(resp.json.items())[0][1][0].keys()) == {'state', 'count'}
def test_dag_stats_failure(self):
self.logout()
@@ -1959,10 +1956,10 @@ def test_dag_stats_success_when_selecting_dags(self):
resp = self.client.post(
'dag_stats', data={'dag_ids': ['example_subdag_operator']}, follow_redirects=True
)
- self.assertEqual(resp.status_code, 200)
+ assert resp.status_code == 200
stats = json.loads(resp.data.decode('utf-8'))
- self.assertNotIn('example_bash_operator', stats)
- self.assertIn('example_subdag_operator', stats)
+ assert 'example_bash_operator' not in stats
+ assert 'example_subdag_operator' in stats
# Multiple
resp = self.client.post(
@@ -1971,8 +1968,8 @@ def test_dag_stats_success_when_selecting_dags(self):
follow_redirects=True,
)
stats = json.loads(resp.data.decode('utf-8'))
- self.assertIn('example_bash_operator', stats)
- self.assertIn('example_subdag_operator', stats)
+ assert 'example_bash_operator' in stats
+ assert 'example_subdag_operator' in stats
self.check_content_not_in_response('example_xcom', resp)
def test_task_stats_success(self):
@@ -2030,10 +2027,10 @@ def test_task_stats_success_when_selecting_dags(self):
resp = self.client.post(
'task_stats', data={'dag_ids': ['example_subdag_operator']}, follow_redirects=True
)
- self.assertEqual(resp.status_code, 200)
+ assert resp.status_code == 200
stats = json.loads(resp.data.decode('utf-8'))
- self.assertNotIn('example_bash_operator', stats)
- self.assertIn('example_subdag_operator', stats)
+ assert 'example_bash_operator' not in stats
+ assert 'example_subdag_operator' in stats
# Multiple
resp = self.client.post(
@@ -2042,8 +2039,8 @@ def test_task_stats_success_when_selecting_dags(self):
follow_redirects=True,
)
stats = json.loads(resp.data.decode('utf-8'))
- self.assertIn('example_bash_operator', stats)
- self.assertIn('example_subdag_operator', stats)
+ assert 'example_bash_operator' in stats
+ assert 'example_subdag_operator' in stats
self.check_content_not_in_response('example_xcom', resp)
def test_code_success(self):
@@ -2324,10 +2321,10 @@ def test_blocked_success_when_selecting_dags(self):
resp = self.client.post(
'blocked', data={'dag_ids': ['example_subdag_operator']}, follow_redirects=True
)
- self.assertEqual(resp.status_code, 200)
+ assert resp.status_code == 200
blocked_dags = {blocked['dag_id'] for blocked in json.loads(resp.data.decode('utf-8'))}
- self.assertNotIn('example_bash_operator', blocked_dags)
- self.assertIn('example_subdag_operator', blocked_dags)
+ assert 'example_bash_operator' not in blocked_dags
+ assert 'example_subdag_operator' in blocked_dags
# Multiple
resp = self.client.post(
@@ -2336,8 +2333,8 @@ def test_blocked_success_when_selecting_dags(self):
follow_redirects=True,
)
blocked_dags = {blocked['dag_id'] for blocked in json.loads(resp.data.decode('utf-8'))}
- self.assertIn('example_bash_operator', blocked_dags)
- self.assertIn('example_subdag_operator', blocked_dags)
+ assert 'example_bash_operator' in blocked_dags
+ assert 'example_subdag_operator' in blocked_dags
self.check_content_not_in_response('example_xcom', resp)
def test_failed_success(self):
@@ -2629,7 +2626,7 @@ def test_rendered_template_view(self):
"""
Test that the Rendered View contains the values from RenderedTaskInstanceFields
"""
- self.assertEqual(self.task1.bash_command, '{{ task_instance_key_str }}')
+ assert self.task1.bash_command == '{{ task_instance_key_str }}'
ti = TaskInstance(self.task1, self.default_date)
with create_session() as session:
@@ -2647,7 +2644,7 @@ def test_rendered_template_view_for_unexecuted_tis(self):
Test that the Rendered View is able to show rendered values
even for TIs that have not yet executed
"""
- self.assertEqual(self.task1.bash_command, '{{ task_instance_key_str }}')
+ assert self.task1.bash_command == '{{ task_instance_key_str }}'
url = 'rendered-templates?task_id=task1&dag_id=task1&execution_date={}'.format(
self.percent_encode(self.default_date)
@@ -2664,7 +2661,7 @@ def test_user_defined_filter_and_macros_raise_error(self):
self.app.dag_bag = mock.MagicMock(
**{'get_dag.return_value': SerializedDagModel.get(self.dag.dag_id).dag}
)
- self.assertEqual(self.task2.bash_command, 'echo {{ fullname("Apache", "Airflow") | hello }}')
+ assert self.task2.bash_command == 'echo {{ fullname("Apache", "Airflow") | hello }}'
url = 'rendered-templates?task_id=task2&dag_id=testdag&execution_date={}'.format(
self.percent_encode(self.default_date)
@@ -2689,8 +2686,8 @@ def setUp(self):
def test_trigger_dag_button_normal_exist(self):
resp = self.client.get('/', follow_redirects=True)
- self.assertIn('/trigger?dag_id=example_bash_operator', resp.data.decode('utf-8'))
- self.assertIn("return confirmDeleteDag(this, 'example_bash_operator')", resp.data.decode('utf-8'))
+ assert '/trigger?dag_id=example_bash_operator' in resp.data.decode('utf-8')
+ assert "return confirmDeleteDag(this, 'example_bash_operator')" in resp.data.decode('utf-8')
@pytest.mark.quarantined
def test_trigger_dag_button(self):
@@ -2704,9 +2701,9 @@ def test_trigger_dag_button(self):
self.client.post(f'trigger?dag_id={test_dag_id}')
run = self.session.query(DR).filter(DR.dag_id == test_dag_id).first()
- self.assertIsNotNone(run)
- self.assertIn(DagRunType.MANUAL, run.run_id)
- self.assertEqual(run.run_type, DagRunType.MANUAL)
+ assert run is not None
+ assert DagRunType.MANUAL in run.run_id
+ assert run.run_type == DagRunType.MANUAL
@pytest.mark.quarantined
def test_trigger_dag_conf(self):
@@ -2721,10 +2718,10 @@ def test_trigger_dag_conf(self):
self.client.post(f'trigger?dag_id={test_dag_id}', data={'conf': json.dumps(conf_dict)})
run = self.session.query(DR).filter(DR.dag_id == test_dag_id).first()
- self.assertIsNotNone(run)
- self.assertIn(DagRunType.MANUAL, run.run_id)
- self.assertEqual(run.run_type, DagRunType.MANUAL)
- self.assertEqual(run.conf, conf_dict)
+ assert run is not None
+ assert DagRunType.MANUAL in run.run_id
+ assert run.run_type == DagRunType.MANUAL
+ assert run.conf == conf_dict
def test_trigger_dag_conf_malformed(self):
test_dag_id = "example_bash_operator"
@@ -2737,7 +2734,7 @@ def test_trigger_dag_conf_malformed(self):
self.check_content_in_response('Invalid JSON configuration', response)
run = self.session.query(DR).filter(DR.dag_id == test_dag_id).first()
- self.assertIsNone(run)
+ assert run is None
def test_trigger_dag_form(self):
test_dag_id = "example_bash_operator"
@@ -2866,17 +2863,14 @@ def test_extra_links_works(self):
follow_redirects=True,
)
- self.assertEqual(response.status_code, 200)
+ assert response.status_code == 200
response_str = response.data
if isinstance(response.data, bytes):
response_str = response_str.decode()
- self.assertEqual(
- json.loads(response_str),
- {
- 'url': 'http://www.example.com/some_dummy_task/foo-bar/2017-01-01T00:00:00+00:00',
- 'error': None,
- },
- )
+ assert json.loads(response_str) == {
+ 'url': 'http://www.example.com/some_dummy_task/foo-bar/2017-01-01T00:00:00+00:00',
+ 'error': None,
+ }
def test_global_extra_links_works(self):
response = self.client.get(
@@ -2886,13 +2880,11 @@ def test_global_extra_links_works(self):
follow_redirects=True,
)
- self.assertEqual(response.status_code, 200)
+ assert response.status_code == 200
response_str = response.data
if isinstance(response.data, bytes):
response_str = response_str.decode()
- self.assertEqual(
- json.loads(response_str), {'url': 'https://github.com/apache/airflow', 'error': None}
- )
+ assert json.loads(response_str) == {'url': 'https://github.com/apache/airflow', 'error': None}
def test_extra_link_in_gantt_view(self):
exec_date = dates.days_ago(2)
@@ -2913,8 +2905,8 @@ def test_extra_link_in_gantt_view(self):
extra_links_grps = re.search(r'extraLinks\": \[(\".*?\")\]', resp.get_data(as_text=True))
extra_links = extra_links_grps.group(0)
- self.assertIn('airflow', extra_links)
- self.assertIn('github', extra_links)
+ assert 'airflow' in extra_links
+ assert 'github' in extra_links
def test_operator_extra_link_override_global_extra_link(self):
response = self.client.get(
@@ -2924,11 +2916,11 @@ def test_operator_extra_link_override_global_extra_link(self):
follow_redirects=True,
)
- self.assertEqual(response.status_code, 200)
+ assert response.status_code == 200
response_str = response.data
if isinstance(response.data, bytes):
response_str = response_str.decode()
- self.assertEqual(json.loads(response_str), {'url': 'https://airflow.apache.org', 'error': None})
+ assert json.loads(response_str) == {'url': 'https://airflow.apache.org', 'error': None}
def test_extra_links_error_raised(self):
response = self.client.get(
@@ -2938,11 +2930,11 @@ def test_extra_links_error_raised(self):
follow_redirects=True,
)
- self.assertEqual(404, response.status_code)
+ assert 404 == response.status_code
response_str = response.data
if isinstance(response.data, bytes):
response_str = response_str.decode()
- self.assertEqual(json.loads(response_str), {'url': None, 'error': 'This is an error'})
+ assert json.loads(response_str) == {'url': None, 'error': 'This is an error'}
def test_extra_links_no_response(self):
response = self.client.get(
@@ -2952,11 +2944,11 @@ def test_extra_links_no_response(self):
follow_redirects=True,
)
- self.assertEqual(response.status_code, 404)
+ assert response.status_code == 404
response_str = response.data
if isinstance(response.data, bytes):
response_str = response_str.decode()
- self.assertEqual(json.loads(response_str), {'url': None, 'error': 'No URL found for no_response'})
+ assert json.loads(response_str) == {'url': None, 'error': 'No URL found for no_response'}
def test_operator_extra_link_override_plugin(self):
"""
@@ -2973,13 +2965,11 @@ def test_operator_extra_link_override_plugin(self):
follow_redirects=True,
)
- self.assertEqual(response.status_code, 200)
+ assert response.status_code == 200
response_str = response.data
if isinstance(response.data, bytes):
response_str = response_str.decode()
- self.assertEqual(
- json.loads(response_str), {'url': 'https://airflow.apache.org/1.10.5/', 'error': None}
- )
+ assert json.loads(response_str) == {'url': 'https://airflow.apache.org/1.10.5/', 'error': None}
def test_operator_extra_link_multiple_operators(self):
"""
@@ -2997,13 +2987,11 @@ def test_operator_extra_link_multiple_operators(self):
follow_redirects=True,
)
- self.assertEqual(response.status_code, 200)
+ assert response.status_code == 200
response_str = response.data
if isinstance(response.data, bytes):
response_str = response_str.decode()
- self.assertEqual(
- json.loads(response_str), {'url': 'https://airflow.apache.org/1.10.5/', 'error': None}
- )
+ assert json.loads(response_str) == {'url': 'https://airflow.apache.org/1.10.5/', 'error': None}
response = self.client.get(
"{}?dag_id={}&task_id={}&execution_date={}&link_name=airflow".format(
@@ -3012,13 +3000,11 @@ def test_operator_extra_link_multiple_operators(self):
follow_redirects=True,
)
- self.assertEqual(response.status_code, 200)
+ assert response.status_code == 200
response_str = response.data
if isinstance(response.data, bytes):
response_str = response_str.decode()
- self.assertEqual(
- json.loads(response_str), {'url': 'https://airflow.apache.org/1.10.5/', 'error': None}
- )
+ assert json.loads(response_str) == {'url': 'https://airflow.apache.org/1.10.5/', 'error': None}
# Also check that the other Operator Link defined for this operator exists
response = self.client.get(
@@ -3028,11 +3014,11 @@ def test_operator_extra_link_multiple_operators(self):
follow_redirects=True,
)
- self.assertEqual(response.status_code, 200)
+ assert response.status_code == 200
response_str = response.data
if isinstance(response.data, bytes):
response_str = response_str.decode()
- self.assertEqual(json.loads(response_str), {'url': 'https://www.google.com', 'error': None})
+ assert json.loads(response_str) == {'url': 'https://www.google.com', 'error': None}
class TestDagRunModelView(TestBase):
@@ -3060,7 +3046,7 @@ def test_create_dagrun_execution_date_with_timezone_utc(self):
dr = self.session.query(models.DagRun).one()
- self.assertEqual(dr.execution_date, timezone.datetime(2018, 7, 6, 5, 4, 3))
+ assert dr.execution_date == timezone.datetime(2018, 7, 6, 5, 4, 3)
def test_create_dagrun_execution_date_with_timezone_edt(self):
data = {
@@ -3074,7 +3060,7 @@ def test_create_dagrun_execution_date_with_timezone_edt(self):
dr = self.session.query(models.DagRun).one()
- self.assertEqual(dr.execution_date, timezone.datetime(2018, 7, 6, 9, 4, 3))
+ assert dr.execution_date == timezone.datetime(2018, 7, 6, 9, 4, 3)
def test_create_dagrun_execution_date_with_timezone_pst(self):
data = {
@@ -3088,7 +3074,7 @@ def test_create_dagrun_execution_date_with_timezone_pst(self):
dr = self.session.query(models.DagRun).one()
- self.assertEqual(dr.execution_date, timezone.datetime(2018, 7, 6, 13, 4, 3))
+ assert dr.execution_date == timezone.datetime(2018, 7, 6, 13, 4, 3)
@conf_vars({("core", "default_timezone"): "America/Toronto"})
def test_create_dagrun_execution_date_without_timezone_default_edt(self):
@@ -3103,7 +3089,7 @@ def test_create_dagrun_execution_date_without_timezone_default_edt(self):
dr = self.session.query(models.DagRun).one()
- self.assertEqual(dr.execution_date, timezone.datetime(2018, 7, 6, 9, 4, 3))
+ assert dr.execution_date == timezone.datetime(2018, 7, 6, 9, 4, 3)
def test_create_dagrun_execution_date_without_timezone_default_utc(self):
data = {
@@ -3117,7 +3103,7 @@ def test_create_dagrun_execution_date_without_timezone_default_utc(self):
dr = self.session.query(models.DagRun).one()
- self.assertEqual(dr.execution_date, dt(2018, 7, 6, 5, 4, 3, tzinfo=timezone.TIMEZONE))
+ assert dr.execution_date == dt(2018, 7, 6, 5, 4, 3, tzinfo=timezone.TIMEZONE)
def test_create_dagrun_valid_conf(self):
conf_value = dict(Valid=True)
@@ -3132,7 +3118,7 @@ def test_create_dagrun_valid_conf(self):
resp = self.client.post('/dagrun/add', data=data, follow_redirects=True)
self.check_content_in_response('Added Row', resp)
dr = self.session.query(models.DagRun).one()
- self.assertEqual(dr.conf, conf_value)
+ assert dr.conf == conf_value
def test_create_dagrun_invalid_conf(self):
data = {
@@ -3146,7 +3132,7 @@ def test_create_dagrun_invalid_conf(self):
resp = self.client.post('/dagrun/add', data=data, follow_redirects=True)
self.check_content_in_response('JSON Validation Error:', resp)
dr = self.session.query(models.DagRun).all()
- self.assertFalse(dr)
+ assert not dr
def test_list_dagrun_includes_conf(self):
data = {
@@ -3158,8 +3144,8 @@ def test_list_dagrun_includes_conf(self):
}
self.client.post('/dagrun/add', data=data, follow_redirects=True)
dr = self.session.query(models.DagRun).one()
- self.assertEqual(dr.execution_date, timezone.convert_to_utc(datetime(2018, 7, 6, 5, 6, 3)))
- self.assertEqual(dr.conf, {"include": "me"})
+ assert dr.execution_date == timezone.convert_to_utc(datetime(2018, 7, 6, 5, 6, 3))
+ assert dr.conf == {"include": "me"}
resp = self.client.get('/dagrun/list', follow_redirects=True)
self.check_content_in_response("{"include": "me"}", resp)
@@ -3184,7 +3170,7 @@ def test_clear_dag_runs_action(self):
data = {"action": "clear", "rowid": [dr.id]}
resp = self.client.post("/dagrun/action_post", data=data, follow_redirects=True)
self.check_content_in_response("1 dag runs and 2 task instances were cleared", resp)
- self.assertEqual([ti.state for ti in self.session.query(models.TaskInstance).all()], [None, None])
+ assert [ti.state for ti in self.session.query(models.TaskInstance).all()] == [None, None]
def test_clear_dag_runs_action_fails(self):
data = {"action": "clear", "rowid": ["0"]}
@@ -3241,8 +3227,8 @@ def check_last_log(self, dag_id, event, execution_date=None):
if execution_date:
qry = qry.filter(Log.execution_date == execution_date)
logs = qry.order_by(Log.dttm.desc()).limit(5).all()
- self.assertGreaterEqual(len(logs), 1)
- self.assertTrue(logs[0].extra)
+ assert len(logs) >= 1
+ assert logs[0].extra
def test_action_logging_get(self):
url = 'graph?dag_id=example_bash_operator&execution_date={}'.format(
@@ -3297,4 +3283,4 @@ class TestHelperFunctions(TestBase):
def test_get_safe_url(self, test_url, expected_url, mock_url_for):
mock_url_for.return_value = "/home"
with self.app.test_request_context(base_url="http://localhost:8080"):
- self.assertEqual(get_safe_url(test_url), expected_url)
+ assert get_safe_url(test_url) == expected_url