diff --git a/airflow/hooks/dbapi.py b/airflow/hooks/dbapi.py index f86fe2681134c..5aabd03c7c6da 100644 --- a/airflow/hooks/dbapi.py +++ b/airflow/hooks/dbapi.py @@ -67,6 +67,8 @@ class DbApiHook(BaseHook): supports_autocommit = False # Override with the object that exposes the connect method connector = None # type: Optional[ConnectorProtocol] + # Override with db-specific query to check connection + _test_connection_sql = "select 1" def __init__(self, *args, schema: Optional[str] = None, **kwargs): super().__init__() @@ -346,10 +348,10 @@ def bulk_load(self, table, tmp_file): raise NotImplementedError() def test_connection(self): - """Tests the connection by executing a select 1 query""" + """Tests the connection using db-specific query""" status, message = False, '' try: - if self.get_first("select 1"): + if self.get_first(self._test_connection_sql): status = True message = 'Connection successfully tested' except Exception as e: diff --git a/airflow/providers/oracle/hooks/oracle.py b/airflow/providers/oracle/hooks/oracle.py index 6843d8f5c4469..84c0f5d6a1e33 100644 --- a/airflow/providers/oracle/hooks/oracle.py +++ b/airflow/providers/oracle/hooks/oracle.py @@ -339,3 +339,18 @@ def handler(cursor): ) return result + + # TODO: Merge this implementation back to DbApiHook when dropping + # support for Airflow 2.2. + def test_connection(self): + """Tests the connection by executing a select 1 from dual query""" + status, message = False, '' + try: + if self.get_first("select 1 from dual"): + status = True + message = 'Connection successfully tested' + except Exception as e: + status = False + message = str(e) + + return status, message diff --git a/tests/providers/oracle/hooks/test_oracle.py b/tests/providers/oracle/hooks/test_oracle.py index c837c0d7db99b..db2e9f41ade98 100644 --- a/tests/providers/oracle/hooks/test_oracle.py +++ b/tests/providers/oracle/hooks/test_oracle.py @@ -347,3 +347,9 @@ def bindvar(value): expected = [1, 0, 0.0, False, ''] assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3,:4,:5); END;', expected)] assert result == expected + + def test_test_connection_use_dual_table(self): + status, message = self.db_hook.test_connection() + self.cur.execute.assert_called_once_with("select 1 from dual") + assert status is True + assert message == 'Connection successfully tested'