From 746efe7476cc811c50aeabd9f35aefd8dd9c418e Mon Sep 17 00:00:00 2001 From: Madison Swain-Bowden Date: Wed, 4 Oct 2023 20:58:00 -0700 Subject: [PATCH 1/3] Fix AWS RDS hook's DB instance state check --- airflow/providers/amazon/aws/hooks/rds.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/amazon/aws/hooks/rds.py b/airflow/providers/amazon/aws/hooks/rds.py index ce2c0d313ef09..3d4d484dec95e 100644 --- a/airflow/providers/amazon/aws/hooks/rds.py +++ b/airflow/providers/amazon/aws/hooks/rds.py @@ -240,7 +240,7 @@ def get_db_instance_state(self, db_instance_id: str) -> str: try: response = self.conn.describe_db_instances(DBInstanceIdentifier=db_instance_id) except self.conn.exceptions.ClientError as e: - if e.response["Error"]["Code"] == "DBInstanceNotFoundFault": + if e.response["Error"]["Code"] == "DBInstanceNotFound": raise AirflowNotFoundException(e) raise e return response["DBInstances"][0]["DBInstanceStatus"].lower() From d583ff103fd73e5467f6bbf30348fb1f8be1e734 Mon Sep 17 00:00:00 2001 From: Madison Swain-Bowden Date: Fri, 3 Nov 2023 11:53:03 -0700 Subject: [PATCH 2/3] Replace exception checking string with specific exception match --- airflow/providers/amazon/aws/hooks/rds.py | 36 ++++++++--------------- 1 file changed, 12 insertions(+), 24 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/rds.py b/airflow/providers/amazon/aws/hooks/rds.py index 3d4d484dec95e..d577139f02da6 100644 --- a/airflow/providers/amazon/aws/hooks/rds.py +++ b/airflow/providers/amazon/aws/hooks/rds.py @@ -61,10 +61,8 @@ def get_db_snapshot_state(self, snapshot_id: str) -> str: """ try: response = self.conn.describe_db_snapshots(DBSnapshotIdentifier=snapshot_id) - except self.conn.exceptions.ClientError as e: - if e.response["Error"]["Code"] == "DBSnapshotNotFound": - raise AirflowNotFoundException(e) - raise e + except self.conn.exceptions.DBSnapshotNotFoundFault as e: + raise AirflowNotFoundException(e) return response["DBSnapshots"][0]["Status"].lower() def wait_for_db_snapshot_state( @@ -109,10 +107,8 @@ def get_db_cluster_snapshot_state(self, snapshot_id: str) -> str: """ try: response = self.conn.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=snapshot_id) - except self.conn.exceptions.ClientError as e: - if e.response["Error"]["Code"] == "DBClusterSnapshotNotFoundFault": - raise AirflowNotFoundException(e) - raise e + except self.conn.exceptions.DBClusterSnapshotNotFoundFault as e: + raise AirflowNotFoundException(e) return response["DBClusterSnapshots"][0]["Status"].lower() def wait_for_db_cluster_snapshot_state( @@ -157,10 +153,8 @@ def get_export_task_state(self, export_task_id: str) -> str: """ try: response = self.conn.describe_export_tasks(ExportTaskIdentifier=export_task_id) - except self.conn.exceptions.ClientError as e: - if e.response["Error"]["Code"] == "ExportTaskNotFoundFault": - raise AirflowNotFoundException(e) - raise e + except self.conn.exceptions.ExportTaskNotFoundFault as e: + raise AirflowNotFoundException(e) return response["ExportTasks"][0]["Status"].lower() def wait_for_export_task_state( @@ -198,10 +192,8 @@ def get_event_subscription_state(self, subscription_name: str) -> str: """ try: response = self.conn.describe_event_subscriptions(SubscriptionName=subscription_name) - except self.conn.exceptions.ClientError as e: - if e.response["Error"]["Code"] == "SubscriptionNotFoundFault": - raise AirflowNotFoundException(e) - raise e + except self.conn.exceptions.SubscriptionNotFoundFault as e: + raise AirflowNotFoundException(e) return response["EventSubscriptionsList"][0]["Status"].lower() def wait_for_event_subscription_state( @@ -239,10 +231,8 @@ def get_db_instance_state(self, db_instance_id: str) -> str: """ try: response = self.conn.describe_db_instances(DBInstanceIdentifier=db_instance_id) - except self.conn.exceptions.ClientError as e: - if e.response["Error"]["Code"] == "DBInstanceNotFound": - raise AirflowNotFoundException(e) - raise e + except self.conn.exceptions.DBInstanceNotFoundFault as e: + raise AirflowNotFoundException(e) return response["DBInstances"][0]["DBInstanceStatus"].lower() def wait_for_db_instance_state( @@ -292,10 +282,8 @@ def get_db_cluster_state(self, db_cluster_id: str) -> str: """ try: response = self.conn.describe_db_clusters(DBClusterIdentifier=db_cluster_id) - except self.conn.exceptions.ClientError as e: - if e.response["Error"]["Code"] == "DBClusterNotFoundFault": - raise AirflowNotFoundException(e) - raise e + except self.conn.exceptions.DBClusterNotFoundFault as e: + raise AirflowNotFoundException(e) return response["DBClusters"][0]["Status"].lower() def wait_for_db_cluster_state( From 5b0daec16a9704433f066834f4554ef83c3456c0 Mon Sep 17 00:00:00 2001 From: Madison Swain-Bowden Date: Fri, 3 Nov 2023 13:26:44 -0700 Subject: [PATCH 3/3] Revert the two exceptions that don't raise specific faults --- airflow/providers/amazon/aws/hooks/rds.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/rds.py b/airflow/providers/amazon/aws/hooks/rds.py index d577139f02da6..1b84ff018f765 100644 --- a/airflow/providers/amazon/aws/hooks/rds.py +++ b/airflow/providers/amazon/aws/hooks/rds.py @@ -153,8 +153,10 @@ def get_export_task_state(self, export_task_id: str) -> str: """ try: response = self.conn.describe_export_tasks(ExportTaskIdentifier=export_task_id) - except self.conn.exceptions.ExportTaskNotFoundFault as e: - raise AirflowNotFoundException(e) + except self.conn.exceptions.ClientError as e: + if e.response["Error"]["Code"] == "ExportTaskNotFoundFault": + raise AirflowNotFoundException(e) + raise e return response["ExportTasks"][0]["Status"].lower() def wait_for_export_task_state( @@ -192,8 +194,10 @@ def get_event_subscription_state(self, subscription_name: str) -> str: """ try: response = self.conn.describe_event_subscriptions(SubscriptionName=subscription_name) - except self.conn.exceptions.SubscriptionNotFoundFault as e: - raise AirflowNotFoundException(e) + except self.conn.exceptions.ClientError as e: + if e.response["Error"]["Code"] == "SubscriptionNotFoundFault": + raise AirflowNotFoundException(e) + raise e return response["EventSubscriptionsList"][0]["Status"].lower() def wait_for_event_subscription_state(