From 6f530d7aca379647b64aae26878790392536980f Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Sat, 12 Feb 2022 03:04:43 +0530 Subject: [PATCH 01/12] Added template_ext = ('.json') to databricks operators #18925 reference: #18925 --- airflow/providers/databricks/operators/databricks.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 3e3d64adc36ce..715c5041c6a41 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -246,6 +246,7 @@ class DatabricksSubmitRunOperator(BaseOperator): # Used in airflow.models.BaseOperator template_fields: Sequence[str] = ('json',) + template_ext: Sequence[str] = ('json',) # Databricks brand color (blue) under white text ui_color = '#1CB1C2' ui_fgcolor = '#fff' @@ -479,6 +480,7 @@ class DatabricksRunNowOperator(BaseOperator): # Used in airflow.models.BaseOperator template_fields: Sequence[str] = ('json',) + template_ext: Sequence[str] = ('json',) # Databricks brand color (blue) under white text ui_color = '#1CB1C2' ui_fgcolor = '#fff' From b5a2211d0fb2e1475dc96c30473100f0c6e969a0 Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Sat, 12 Feb 2022 10:06:34 +0530 Subject: [PATCH 02/12] Corrected the template_ext value from 'json' to '.json' --- airflow/providers/databricks/operators/databricks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 715c5041c6a41..26c330815daa5 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -246,7 +246,7 @@ class DatabricksSubmitRunOperator(BaseOperator): # Used in airflow.models.BaseOperator template_fields: Sequence[str] = ('json',) - template_ext: Sequence[str] = ('json',) + template_ext: Sequence[str] = ('.json',) # Databricks brand color (blue) under white text ui_color = '#1CB1C2' ui_fgcolor = '#fff' @@ -480,7 +480,7 @@ class DatabricksRunNowOperator(BaseOperator): # Used in airflow.models.BaseOperator template_fields: Sequence[str] = ('json',) - template_ext: Sequence[str] = ('json',) + template_ext: Sequence[str] = ('.json',) # Databricks brand color (blue) under white text ui_color = '#1CB1C2' ui_fgcolor = '#fff' From c8dd3826f7f0fa90e8993b804e02b60d5eaeab36 Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Sun, 13 Feb 2022 23:40:12 +0530 Subject: [PATCH 03/12] Added retries to LivyHook #19384 --- airflow/providers/apache/livy/hooks/livy.py | 20 ++++++++++++++++--- .../providers/apache/livy/operators/livy.py | 8 ++++++-- .../apache/livy/operators/test_livy.py | 12 +++++------ 3 files changed, 29 insertions(+), 11 deletions(-) diff --git a/airflow/providers/apache/livy/hooks/livy.py b/airflow/providers/apache/livy/hooks/livy.py index 865218c51e7b8..7f519ca0000c7 100644 --- a/airflow/providers/apache/livy/hooks/livy.py +++ b/airflow/providers/apache/livy/hooks/livy.py @@ -99,6 +99,7 @@ def run_method( method: str = 'GET', data: Optional[Any] = None, headers: Optional[Dict[str, Any]] = None, + _retry_args: Optional[Dict[str, Any]] = None ) -> Any: """ Wrapper for HttpHook, allows to change method on the same HttpHook @@ -107,6 +108,8 @@ def run_method( :param endpoint: endpoint :param data: request payload :param headers: headers + :param _retry_args: Arguments which define the retry behaviour. + See Tenacity documentation at https://github.com/jd/tenacity :return: http response :rtype: requests.Response """ @@ -118,7 +121,16 @@ def run_method( back_method = self.method self.method = method try: - result = self.run(endpoint, data, headers, self.extra_options) + if _retry_args: + result = self.run_with_advanced_retry( + endpoint=endpoint, + data=data, + headers=headers, + extra_options=self.extra_options, + _retry_args=_retry_args) + else: + result = self.run(endpoint, data, headers, self.extra_options) + finally: self.method = back_method return result @@ -180,18 +192,20 @@ def get_batch(self, session_id: Union[int, str]) -> Any: return response.json() - def get_batch_state(self, session_id: Union[int, str]) -> BatchState: + def get_batch_state(self, session_id: Union[int, str], _retry_args: Optional[Dict[str, Any]] = None) -> BatchState: """ Fetch the state of the specified batch :param session_id: identifier of the batch sessions + :param _retry_args: Arguments which define the retry behaviour. + See Tenacity documentation at https://github.com/jd/tenacity :return: batch state :rtype: BatchState """ self._validate_session_id(session_id) self.log.debug("Fetching info for batch session %d", session_id) - response = self.run_method(endpoint=f'/batches/{session_id}/state') + response = self.run_method(endpoint=f'/batches/{session_id}/state', _retry_args=_retry_args) try: response.raise_for_status() diff --git a/airflow/providers/apache/livy/operators/livy.py b/airflow/providers/apache/livy/operators/livy.py index 3b0a2bb93277c..7db60c26d1ddd 100644 --- a/airflow/providers/apache/livy/operators/livy.py +++ b/airflow/providers/apache/livy/operators/livy.py @@ -53,6 +53,8 @@ class LivyOperator(BaseOperator): :param extra_options: A dictionary of options, where key is string and value depends on the option that's being modified. :param extra_headers: A dictionary of headers passed to the HTTP request to livy. + :param _retry_args: Arguments which define the retry behaviour. + See Tenacity documentation at https://github.com/jd/tenacity """ template_fields: Sequence[str] = ('spark_params',) @@ -80,6 +82,7 @@ def __init__( polling_interval: int = 0, extra_options: Optional[Dict[str, Any]] = None, extra_headers: Optional[Dict[str, Any]] = None, + _retry_args: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: @@ -111,6 +114,7 @@ def __init__( self._livy_hook: Optional[LivyHook] = None self._batch_id: Union[int, str] + self._retry_args = _retry_args def get_hook(self) -> LivyHook: """ @@ -142,11 +146,11 @@ def poll_for_termination(self, batch_id: Union[int, str]) -> None: :param batch_id: id of the batch session to monitor. """ hook = self.get_hook() - state = hook.get_batch_state(batch_id) + state = hook.get_batch_state(batch_id, _retry_args=self._retry_args) while state not in hook.TERMINAL_STATES: self.log.debug('Batch with id %s is in state: %s', batch_id, state.value) sleep(self._polling_interval) - state = hook.get_batch_state(batch_id) + state = hook.get_batch_state(batch_id, _retry_args=self._retry_args) self.log.info("Batch with id %s terminated with state: %s", batch_id, state.value) hook.dump_batch_logs(batch_id) if state != BatchState.SUCCESS: diff --git a/tests/providers/apache/livy/operators/test_livy.py b/tests/providers/apache/livy/operators/test_livy.py index 8038f32125ec7..153f5a9644240 100644 --- a/tests/providers/apache/livy/operators/test_livy.py +++ b/tests/providers/apache/livy/operators/test_livy.py @@ -55,7 +55,7 @@ def test_poll_for_termination(self, mock_livy, mock_dump_logs): state_list = 2 * [BatchState.RUNNING] + [BatchState.SUCCESS] - def side_effect(_): + def side_effect(_, _retry_args): if state_list: return state_list.pop(0) # fail if does not stop right before @@ -67,7 +67,7 @@ def side_effect(_): task._livy_hook = task.get_hook() task.poll_for_termination(BATCH_ID) - mock_livy.assert_called_with(BATCH_ID) + mock_livy.assert_called_with(BATCH_ID, _retry_args=None) mock_dump_logs.assert_called_with(BATCH_ID) assert mock_livy.call_count == 3 @@ -80,7 +80,7 @@ def test_poll_for_termination_fail(self, mock_livy, mock_dump_logs): state_list = 2 * [BatchState.RUNNING] + [BatchState.ERROR] - def side_effect(_): + def side_effect(_, _retry_args): if state_list: return state_list.pop(0) # fail if does not stop right before @@ -94,7 +94,7 @@ def side_effect(_): with pytest.raises(AirflowException): task.poll_for_termination(BATCH_ID) - mock_livy.assert_called_with(BATCH_ID) + mock_livy.assert_called_with(BATCH_ID, _retry_args=None) mock_dump_logs.assert_called_with(BATCH_ID) assert mock_livy.call_count == 3 @@ -119,7 +119,7 @@ def test_execution(self, mock_post, mock_get, mock_dump_logs): call_args = {k: v for k, v in mock_post.call_args[1].items() if v} assert call_args == {'file': 'sparkapp'} - mock_get.assert_called_once_with(BATCH_ID) + mock_get.assert_called_once_with(BATCH_ID, _retry_args=None) mock_dump_logs.assert_called_once_with(BATCH_ID) @patch('airflow.providers.apache.livy.operators.livy.LivyHook.post_batch') @@ -171,5 +171,5 @@ def test_log_dump(self, mock_post, mock_get_logs, mock_get): assert 'INFO:airflow.providers.apache.livy.hooks.livy.LivyHook:first_line' in cm.output assert 'INFO:airflow.providers.apache.livy.hooks.livy.LivyHook:second_line' in cm.output assert 'INFO:airflow.providers.apache.livy.hooks.livy.LivyHook:third_line' in cm.output - mock_get.assert_called_once_with(BATCH_ID) + mock_get.assert_called_once_with(BATCH_ID, _retry_args=None) mock_get_logs.assert_called_once_with(BATCH_ID, 0, 100) From a7b58b8ead47b507df44be466adb451680bdde2b Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Mon, 14 Feb 2022 01:19:37 +0530 Subject: [PATCH 04/12] Static code fixes. --- airflow/providers/apache/livy/hooks/livy.py | 9 ++++++--- airflow/providers/apache/livy/operators/livy.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/airflow/providers/apache/livy/hooks/livy.py b/airflow/providers/apache/livy/hooks/livy.py index 7f519ca0000c7..9898073d59e94 100644 --- a/airflow/providers/apache/livy/hooks/livy.py +++ b/airflow/providers/apache/livy/hooks/livy.py @@ -99,7 +99,7 @@ def run_method( method: str = 'GET', data: Optional[Any] = None, headers: Optional[Dict[str, Any]] = None, - _retry_args: Optional[Dict[str, Any]] = None + _retry_args: Optional[Dict[str, Any]] = None, ) -> Any: """ Wrapper for HttpHook, allows to change method on the same HttpHook @@ -127,7 +127,8 @@ def run_method( data=data, headers=headers, extra_options=self.extra_options, - _retry_args=_retry_args) + _retry_args=_retry_args, + ) else: result = self.run(endpoint, data, headers, self.extra_options) @@ -192,7 +193,9 @@ def get_batch(self, session_id: Union[int, str]) -> Any: return response.json() - def get_batch_state(self, session_id: Union[int, str], _retry_args: Optional[Dict[str, Any]] = None) -> BatchState: + def get_batch_state( + self, session_id: Union[int, str], _retry_args: Optional[Dict[str, Any]] = None + ) -> BatchState: """ Fetch the state of the specified batch diff --git a/airflow/providers/apache/livy/operators/livy.py b/airflow/providers/apache/livy/operators/livy.py index 7db60c26d1ddd..2d713c28d8db2 100644 --- a/airflow/providers/apache/livy/operators/livy.py +++ b/airflow/providers/apache/livy/operators/livy.py @@ -82,7 +82,7 @@ def __init__( polling_interval: int = 0, extra_options: Optional[Dict[str, Any]] = None, extra_headers: Optional[Dict[str, Any]] = None, - _retry_args: Optional[Dict[str, Any]] = None, + _retry_args: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: From eb6543f63f894389a9c44163861ee77e444140cd Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Mon, 14 Feb 2022 21:47:18 +0530 Subject: [PATCH 05/12] Renamed _retry_args to retry_args. --- airflow/providers/apache/livy/hooks/livy.py | 14 +++++++------- airflow/providers/apache/livy/operators/livy.py | 10 +++++----- tests/providers/apache/livy/operators/test_livy.py | 12 ++++++------ 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/airflow/providers/apache/livy/hooks/livy.py b/airflow/providers/apache/livy/hooks/livy.py index 9898073d59e94..296a493ef363e 100644 --- a/airflow/providers/apache/livy/hooks/livy.py +++ b/airflow/providers/apache/livy/hooks/livy.py @@ -99,7 +99,7 @@ def run_method( method: str = 'GET', data: Optional[Any] = None, headers: Optional[Dict[str, Any]] = None, - _retry_args: Optional[Dict[str, Any]] = None, + retry_args: Optional[Dict[str, Any]] = None, ) -> Any: """ Wrapper for HttpHook, allows to change method on the same HttpHook @@ -108,7 +108,7 @@ def run_method( :param endpoint: endpoint :param data: request payload :param headers: headers - :param _retry_args: Arguments which define the retry behaviour. + :param retry_args: Arguments which define the retry behaviour. See Tenacity documentation at https://github.com/jd/tenacity :return: http response :rtype: requests.Response @@ -121,13 +121,13 @@ def run_method( back_method = self.method self.method = method try: - if _retry_args: + if retry_args: result = self.run_with_advanced_retry( endpoint=endpoint, data=data, headers=headers, extra_options=self.extra_options, - _retry_args=_retry_args, + retry_args=retry_args, ) else: result = self.run(endpoint, data, headers, self.extra_options) @@ -194,13 +194,13 @@ def get_batch(self, session_id: Union[int, str]) -> Any: return response.json() def get_batch_state( - self, session_id: Union[int, str], _retry_args: Optional[Dict[str, Any]] = None + self, session_id: Union[int, str], retry_args: Optional[Dict[str, Any]] = None ) -> BatchState: """ Fetch the state of the specified batch :param session_id: identifier of the batch sessions - :param _retry_args: Arguments which define the retry behaviour. + :param retry_args: Arguments which define the retry behaviour. See Tenacity documentation at https://github.com/jd/tenacity :return: batch state :rtype: BatchState @@ -208,7 +208,7 @@ def get_batch_state( self._validate_session_id(session_id) self.log.debug("Fetching info for batch session %d", session_id) - response = self.run_method(endpoint=f'/batches/{session_id}/state', _retry_args=_retry_args) + response = self.run_method(endpoint=f'/batches/{session_id}/state', retry_args=retry_args) try: response.raise_for_status() diff --git a/airflow/providers/apache/livy/operators/livy.py b/airflow/providers/apache/livy/operators/livy.py index 2d713c28d8db2..f0dbc9e3165fe 100644 --- a/airflow/providers/apache/livy/operators/livy.py +++ b/airflow/providers/apache/livy/operators/livy.py @@ -53,7 +53,7 @@ class LivyOperator(BaseOperator): :param extra_options: A dictionary of options, where key is string and value depends on the option that's being modified. :param extra_headers: A dictionary of headers passed to the HTTP request to livy. - :param _retry_args: Arguments which define the retry behaviour. + :param retry_args: Arguments which define the retry behaviour. See Tenacity documentation at https://github.com/jd/tenacity """ @@ -82,7 +82,7 @@ def __init__( polling_interval: int = 0, extra_options: Optional[Dict[str, Any]] = None, extra_headers: Optional[Dict[str, Any]] = None, - _retry_args: Optional[Dict[str, Any]] = None, + retry_args: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: @@ -114,7 +114,7 @@ def __init__( self._livy_hook: Optional[LivyHook] = None self._batch_id: Union[int, str] - self._retry_args = _retry_args + self.retry_args = retry_args def get_hook(self) -> LivyHook: """ @@ -146,11 +146,11 @@ def poll_for_termination(self, batch_id: Union[int, str]) -> None: :param batch_id: id of the batch session to monitor. """ hook = self.get_hook() - state = hook.get_batch_state(batch_id, _retry_args=self._retry_args) + state = hook.get_batch_state(batch_id, retry_args=self.retry_args) while state not in hook.TERMINAL_STATES: self.log.debug('Batch with id %s is in state: %s', batch_id, state.value) sleep(self._polling_interval) - state = hook.get_batch_state(batch_id, _retry_args=self._retry_args) + state = hook.get_batch_state(batch_id, retry_args=self.retry_args) self.log.info("Batch with id %s terminated with state: %s", batch_id, state.value) hook.dump_batch_logs(batch_id) if state != BatchState.SUCCESS: diff --git a/tests/providers/apache/livy/operators/test_livy.py b/tests/providers/apache/livy/operators/test_livy.py index 153f5a9644240..f1c43bf5f1759 100644 --- a/tests/providers/apache/livy/operators/test_livy.py +++ b/tests/providers/apache/livy/operators/test_livy.py @@ -55,7 +55,7 @@ def test_poll_for_termination(self, mock_livy, mock_dump_logs): state_list = 2 * [BatchState.RUNNING] + [BatchState.SUCCESS] - def side_effect(_, _retry_args): + def side_effect(_, retry_args): if state_list: return state_list.pop(0) # fail if does not stop right before @@ -67,7 +67,7 @@ def side_effect(_, _retry_args): task._livy_hook = task.get_hook() task.poll_for_termination(BATCH_ID) - mock_livy.assert_called_with(BATCH_ID, _retry_args=None) + mock_livy.assert_called_with(BATCH_ID, retry_args=None) mock_dump_logs.assert_called_with(BATCH_ID) assert mock_livy.call_count == 3 @@ -80,7 +80,7 @@ def test_poll_for_termination_fail(self, mock_livy, mock_dump_logs): state_list = 2 * [BatchState.RUNNING] + [BatchState.ERROR] - def side_effect(_, _retry_args): + def side_effect(_, retry_args): if state_list: return state_list.pop(0) # fail if does not stop right before @@ -94,7 +94,7 @@ def side_effect(_, _retry_args): with pytest.raises(AirflowException): task.poll_for_termination(BATCH_ID) - mock_livy.assert_called_with(BATCH_ID, _retry_args=None) + mock_livy.assert_called_with(BATCH_ID, retry_args=None) mock_dump_logs.assert_called_with(BATCH_ID) assert mock_livy.call_count == 3 @@ -119,7 +119,7 @@ def test_execution(self, mock_post, mock_get, mock_dump_logs): call_args = {k: v for k, v in mock_post.call_args[1].items() if v} assert call_args == {'file': 'sparkapp'} - mock_get.assert_called_once_with(BATCH_ID, _retry_args=None) + mock_get.assert_called_once_with(BATCH_ID, retry_args=None) mock_dump_logs.assert_called_once_with(BATCH_ID) @patch('airflow.providers.apache.livy.operators.livy.LivyHook.post_batch') @@ -171,5 +171,5 @@ def test_log_dump(self, mock_post, mock_get_logs, mock_get): assert 'INFO:airflow.providers.apache.livy.hooks.livy.LivyHook:first_line' in cm.output assert 'INFO:airflow.providers.apache.livy.hooks.livy.LivyHook:second_line' in cm.output assert 'INFO:airflow.providers.apache.livy.hooks.livy.LivyHook:third_line' in cm.output - mock_get.assert_called_once_with(BATCH_ID, _retry_args=None) + mock_get.assert_called_once_with(BATCH_ID, retry_args=None) mock_get_logs.assert_called_once_with(BATCH_ID, 0, 100) From 3e86860147154c472c6c44ee92736d888e26b483 Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Mon, 14 Feb 2022 23:12:01 +0530 Subject: [PATCH 06/12] Passed _retry_args in run_with_advanced_retry(). --- airflow/providers/apache/livy/hooks/livy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/apache/livy/hooks/livy.py b/airflow/providers/apache/livy/hooks/livy.py index 296a493ef363e..e2d084ec61f78 100644 --- a/airflow/providers/apache/livy/hooks/livy.py +++ b/airflow/providers/apache/livy/hooks/livy.py @@ -127,7 +127,7 @@ def run_method( data=data, headers=headers, extra_options=self.extra_options, - retry_args=retry_args, + _retry_args=retry_args, ) else: result = self.run(endpoint, data, headers, self.extra_options) From a9b30ee0de0a5642b096b3447f8df17da3b2c9ab Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Sat, 19 Feb 2022 12:53:51 +0530 Subject: [PATCH 07/12] Added deprecation warning when using hql param. --- airflow/providers/presto/hooks/presto.py | 49 +++++++++++++++++++---- airflow/providers/trino/hooks/trino.py | 50 ++++++++++++++++++++---- 2 files changed, 83 insertions(+), 16 deletions(-) diff --git a/airflow/providers/presto/hooks/presto.py b/airflow/providers/presto/hooks/presto.py index 83f680c9bb376..051b30107c7eb 100644 --- a/airflow/providers/presto/hooks/presto.py +++ b/airflow/providers/presto/hooks/presto.py @@ -111,27 +111,51 @@ def get_isolation_level(self) -> Any: def _strip_sql(sql: str) -> str: return sql.strip().rstrip(';') - def get_records(self, hql, parameters: Optional[dict] = None): + def get_records(self, sql, parameters: Optional[dict] = None, **kwargs): """Get a set of records from Presto""" + if kwargs.get('hql'): + warnings.warn( + "The hql parameter has been deprecated. You should pass the sql parameter.", + DeprecationWarning, + stacklevel=2, + ) + sql = kwargs.get('hql') + try: - return super().get_records(self._strip_sql(hql), parameters) + return super().get_records(self._strip_sql(sql), parameters) except DatabaseError as e: raise PrestoException(e) - def get_first(self, hql: str, parameters: Optional[dict] = None) -> Any: + def get_first(self, sql: str, parameters: Optional[dict] = None, **kwargs) -> Any: """Returns only the first row, regardless of how many rows the query returns.""" + if kwargs.get('hql'): + warnings.warn( + "The hql parameter has been deprecated. You should pass the sql parameter.", + DeprecationWarning, + stacklevel=2, + ) + sql = kwargs.get('hql') + try: - return super().get_first(self._strip_sql(hql), parameters) + return super().get_first(self._strip_sql(sql), parameters) except DatabaseError as e: raise PrestoException(e) - def get_pandas_df(self, hql, parameters=None, **kwargs): + def get_pandas_df(self, sql, parameters=None, **kwargs): """Get a pandas dataframe from a sql query.""" + if kwargs.get('hql'): + warnings.warn( + "The hql parameter has been deprecated. You should pass the sql parameter.", + DeprecationWarning, + stacklevel=2, + ) + sql = kwargs.get('hql') + import pandas cursor = self.get_cursor() try: - cursor.execute(self._strip_sql(hql), parameters) + cursor.execute(self._strip_sql(sql), parameters) data = cursor.fetchall() except DatabaseError as e: raise PrestoException(e) @@ -145,13 +169,22 @@ def get_pandas_df(self, hql, parameters=None, **kwargs): def run( self, - hql, + sql, autocommit: bool = False, parameters: Optional[dict] = None, handler: Optional[Callable] = None, + **kwargs ) -> None: """Execute the statement against Presto. Can be used to create views.""" - return super().run(sql=self._strip_sql(hql), parameters=parameters, handler=handler) + if kwargs.get('hql'): + warnings.warn( + "The hql parameter has been deprecated. You should pass the sql parameter.", + DeprecationWarning, + stacklevel=2, + ) + sql = kwargs.get('hql') + + return super().run(sql=self._strip_sql(sql), parameters=parameters, handler=handler) def insert_rows( self, diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py index 4ec6f301e7843..54d675e7f5dda 100644 --- a/airflow/providers/trino/hooks/trino.py +++ b/airflow/providers/trino/hooks/trino.py @@ -16,6 +16,7 @@ # specific language governing permissions and limitations # under the License. import os +import warnings from typing import Any, Callable, Iterable, Optional import trino @@ -106,27 +107,51 @@ def get_isolation_level(self) -> Any: def _strip_sql(sql: str) -> str: return sql.strip().rstrip(';') - def get_records(self, hql: str, parameters: Optional[dict] = None): + def get_records(self, sql: str = None, parameters: Optional[dict] = None, **kwargs): """Get a set of records from Trino""" + if kwargs.get('hql'): + warnings.warn( + "The hql parameter has been deprecated. You should pass the sql parameter.", + DeprecationWarning, + stacklevel=2, + ) + sql = kwargs.get('hql') + try: - return super().get_records(self._strip_sql(hql), parameters) + return super().get_records(self._strip_sql(sql), parameters) except DatabaseError as e: raise TrinoException(e) - def get_first(self, hql: str, parameters: Optional[dict] = None) -> Any: + def get_first(self, sql: str = None, parameters: Optional[dict] = None, **kwargs) -> Any: """Returns only the first row, regardless of how many rows the query returns.""" + if kwargs.get('hql'): + warnings.warn( + "The hql parameter has been deprecated. You should pass the sql parameter.", + DeprecationWarning, + stacklevel=2, + ) + sql = kwargs.get('hql') + try: - return super().get_first(self._strip_sql(hql), parameters) + return super().get_first(self._strip_sql(sql), parameters) except DatabaseError as e: raise TrinoException(e) - def get_pandas_df(self, hql: str, parameters: Optional[dict] = None, **kwargs): # type: ignore[override] + def get_pandas_df(self, sql: str = None, parameters: Optional[dict] = None, **kwargs): # type: ignore[override] """Get a pandas dataframe from a sql query.""" + if kwargs.get('hql'): + warnings.warn( + "The hql parameter has been deprecated. You should pass the sql parameter.", + DeprecationWarning, + stacklevel=2, + ) + sql = kwargs.get('hql') + import pandas cursor = self.get_cursor() try: - cursor.execute(self._strip_sql(hql), parameters) + cursor.execute(self._strip_sql(sql), parameters) data = cursor.fetchall() except DatabaseError as e: raise TrinoException(e) @@ -140,14 +165,23 @@ def get_pandas_df(self, hql: str, parameters: Optional[dict] = None, **kwargs): def run( self, - hql: str, + sql: str, autocommit: bool = False, parameters: Optional[dict] = None, handler: Optional[Callable] = None, + **kwargs ) -> None: """Execute the statement against Trino. Can be used to create views.""" + if kwargs.get('hql'): + warnings.warn( + "The hql parameter has been deprecated. You should pass the sql parameter.", + DeprecationWarning, + stacklevel=2, + ) + sql = kwargs.get('hql') + return super().run( - sql=self._strip_sql(hql), autocommit=autocommit, parameters=parameters, handler=handler + sql=self._strip_sql(sql), autocommit=autocommit, parameters=parameters, handler=handler ) def insert_rows( From e0bfed2a92d5caca5cc3e3456bf89d4cb98fc95f Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Sat, 19 Feb 2022 17:52:37 +0530 Subject: [PATCH 08/12] Fixed static checks. --- airflow/providers/presto/hooks/presto.py | 27 ++++++++++++----------- airflow/providers/trino/hooks/trino.py | 28 +++++++++++++----------- 2 files changed, 29 insertions(+), 26 deletions(-) diff --git a/airflow/providers/presto/hooks/presto.py b/airflow/providers/presto/hooks/presto.py index 051b30107c7eb..48eb3a2331808 100644 --- a/airflow/providers/presto/hooks/presto.py +++ b/airflow/providers/presto/hooks/presto.py @@ -16,6 +16,7 @@ # specific language governing permissions and limitations # under the License. import os +import warnings from typing import Any, Callable, Iterable, Optional import prestodb @@ -111,45 +112,45 @@ def get_isolation_level(self) -> Any: def _strip_sql(sql: str) -> str: return sql.strip().rstrip(';') - def get_records(self, sql, parameters: Optional[dict] = None, **kwargs): + def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql: str = ""): """Get a set of records from Presto""" - if kwargs.get('hql'): + if hql: warnings.warn( "The hql parameter has been deprecated. You should pass the sql parameter.", DeprecationWarning, stacklevel=2, ) - sql = kwargs.get('hql') + sql = hql try: return super().get_records(self._strip_sql(sql), parameters) except DatabaseError as e: raise PrestoException(e) - def get_first(self, sql: str, parameters: Optional[dict] = None, **kwargs) -> Any: + def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql: str = "") -> Any: """Returns only the first row, regardless of how many rows the query returns.""" - if kwargs.get('hql'): + if hql: warnings.warn( "The hql parameter has been deprecated. You should pass the sql parameter.", DeprecationWarning, stacklevel=2, ) - sql = kwargs.get('hql') + sql = hql try: return super().get_first(self._strip_sql(sql), parameters) except DatabaseError as e: raise PrestoException(e) - def get_pandas_df(self, sql, parameters=None, **kwargs): + def get_pandas_df(self, sql: str = "", parameters=None, hql: str = "", **kwargs): """Get a pandas dataframe from a sql query.""" - if kwargs.get('hql'): + if hql: warnings.warn( "The hql parameter has been deprecated. You should pass the sql parameter.", DeprecationWarning, stacklevel=2, ) - sql = kwargs.get('hql') + sql = hql import pandas @@ -169,20 +170,20 @@ def get_pandas_df(self, sql, parameters=None, **kwargs): def run( self, - sql, + sql: str = "", autocommit: bool = False, parameters: Optional[dict] = None, handler: Optional[Callable] = None, - **kwargs + hql: str = "", ) -> None: """Execute the statement against Presto. Can be used to create views.""" - if kwargs.get('hql'): + if hql: warnings.warn( "The hql parameter has been deprecated. You should pass the sql parameter.", DeprecationWarning, stacklevel=2, ) - sql = kwargs.get('hql') + sql = hql return super().run(sql=self._strip_sql(sql), parameters=parameters, handler=handler) diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py index 54d675e7f5dda..793aaac7946aa 100644 --- a/airflow/providers/trino/hooks/trino.py +++ b/airflow/providers/trino/hooks/trino.py @@ -107,45 +107,47 @@ def get_isolation_level(self) -> Any: def _strip_sql(sql: str) -> str: return sql.strip().rstrip(';') - def get_records(self, sql: str = None, parameters: Optional[dict] = None, **kwargs): + def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql: str = ""): """Get a set of records from Trino""" - if kwargs.get('hql'): + if hql: warnings.warn( "The hql parameter has been deprecated. You should pass the sql parameter.", DeprecationWarning, stacklevel=2, ) - sql = kwargs.get('hql') + sql = hql try: return super().get_records(self._strip_sql(sql), parameters) except DatabaseError as e: raise TrinoException(e) - def get_first(self, sql: str = None, parameters: Optional[dict] = None, **kwargs) -> Any: + def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql: str = "") -> Any: """Returns only the first row, regardless of how many rows the query returns.""" - if kwargs.get('hql'): + if hql: warnings.warn( "The hql parameter has been deprecated. You should pass the sql parameter.", DeprecationWarning, stacklevel=2, ) - sql = kwargs.get('hql') + sql = hql try: return super().get_first(self._strip_sql(sql), parameters) except DatabaseError as e: raise TrinoException(e) - def get_pandas_df(self, sql: str = None, parameters: Optional[dict] = None, **kwargs): # type: ignore[override] + def get_pandas_df( + self, sql: str = "", parameters: Optional[dict] = None, hql: str = "", **kwargs + ): # type: ignore[override] """Get a pandas dataframe from a sql query.""" - if kwargs.get('hql'): + if hql: warnings.warn( "The hql parameter has been deprecated. You should pass the sql parameter.", DeprecationWarning, stacklevel=2, ) - sql = kwargs.get('hql') + sql = hql import pandas @@ -165,20 +167,20 @@ def get_pandas_df(self, sql: str = None, parameters: Optional[dict] = None, **kw def run( self, - sql: str, + sql: str = "", autocommit: bool = False, parameters: Optional[dict] = None, handler: Optional[Callable] = None, - **kwargs + hql: str = "", ) -> None: """Execute the statement against Trino. Can be used to create views.""" - if kwargs.get('hql'): + if hql: warnings.warn( "The hql parameter has been deprecated. You should pass the sql parameter.", DeprecationWarning, stacklevel=2, ) - sql = kwargs.get('hql') + sql = hql return super().run( sql=self._strip_sql(sql), autocommit=autocommit, parameters=parameters, handler=handler From bd3f86954ef6c4862b4755e0390d489bf3720139 Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Mon, 21 Feb 2022 12:44:08 +0530 Subject: [PATCH 09/12] Added @overload signatures and sphink's markers to remove 'hql' from documentation. --- airflow/providers/presto/hooks/presto.py | 68 ++++++++++++++++++++-- airflow/providers/trino/hooks/trino.py | 72 ++++++++++++++++++++++-- 2 files changed, 130 insertions(+), 10 deletions(-) diff --git a/airflow/providers/presto/hooks/presto.py b/airflow/providers/presto/hooks/presto.py index 48eb3a2331808..1f6d3d3f0674d 100644 --- a/airflow/providers/presto/hooks/presto.py +++ b/airflow/providers/presto/hooks/presto.py @@ -17,7 +17,7 @@ # under the License. import os import warnings -from typing import Any, Callable, Iterable, Optional +from typing import Any, Callable, Iterable, Optional, overload import prestodb from prestodb.exceptions import DatabaseError @@ -112,8 +112,20 @@ def get_isolation_level(self) -> Any: def _strip_sql(sql: str) -> str: return sql.strip().rstrip(';') + @overload + def get_records(self, sql: str = "", parameters: Optional[dict] = None): + """Get a set of records from Presto + :param sql: SQL statement to be executed. + :param parameters: The parameters to render the SQL query with. + """ + ... + + @overload def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql: str = ""): - """Get a set of records from Presto""" + ... + + def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql=None): + """:sphinx-autoapi-skip:""" if hql: warnings.warn( "The hql parameter has been deprecated. You should pass the sql parameter.", @@ -127,8 +139,20 @@ def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql: str except DatabaseError as e: raise PrestoException(e) + @overload + def get_first(self, sql: str = "", parameters: Optional[dict] = None) -> Any: + """Returns only the first row, regardless of how many rows the query returns. + :param sql: SQL statement to be executed. + :param parameters: The parameters to render the SQL query with. + """ + ... + + @overload def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql: str = "") -> Any: - """Returns only the first row, regardless of how many rows the query returns.""" + ... + + def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql=None) -> Any: + """:sphinx-autoapi-skip:""" if hql: warnings.warn( "The hql parameter has been deprecated. You should pass the sql parameter.", @@ -142,8 +166,20 @@ def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql: str = except DatabaseError as e: raise PrestoException(e) + @overload + def get_pandas_df(self, sql: str = "", parameters=None, **kwargs): + """Get a pandas dataframe from a sql query. + :param sql: SQL statement to be executed. + :param parameters: The parameters to render the SQL query with. + """ + ... + + @overload def get_pandas_df(self, sql: str = "", parameters=None, hql: str = "", **kwargs): - """Get a pandas dataframe from a sql query.""" + ... + + def get_pandas_df(self, sql: str = "", parameters=None, hql=None, **kwargs): + """:sphinx-autoapi-skip:""" if hql: warnings.warn( "The hql parameter has been deprecated. You should pass the sql parameter.", @@ -168,15 +204,37 @@ def get_pandas_df(self, sql: str = "", parameters=None, hql: str = "", **kwargs) df = pandas.DataFrame(**kwargs) return df + @overload def run( self, sql: str = "", autocommit: bool = False, parameters: Optional[dict] = None, handler: Optional[Callable] = None, - hql: str = "", ) -> None: """Execute the statement against Presto. Can be used to create views.""" + ... + + @overload + def run( + self, + sql: str = "", + autocommit: bool = False, + parameters: Optional[dict] = None, + handler: Optional[Callable] = None, + hql: str = "", + ) -> None: + ... + + def run( + self, + sql: str = "", + autocommit: bool = False, + parameters: Optional[dict] = None, + handler: Optional[Callable] = None, + hql=None, + ) -> None: + """:sphinx-autoapi-skip:""" if hql: warnings.warn( "The hql parameter has been deprecated. You should pass the sql parameter.", diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py index 793aaac7946aa..c0423ca7054cc 100644 --- a/airflow/providers/trino/hooks/trino.py +++ b/airflow/providers/trino/hooks/trino.py @@ -17,7 +17,7 @@ # under the License. import os import warnings -from typing import Any, Callable, Iterable, Optional +from typing import Any, Callable, Iterable, Optional, overload import trino from trino.exceptions import DatabaseError @@ -107,8 +107,20 @@ def get_isolation_level(self) -> Any: def _strip_sql(sql: str) -> str: return sql.strip().rstrip(';') + @overload + def get_records(self, sql: str = "", parameters: Optional[dict] = None): + """Get a set of records from Trino + :param sql: SQL statement to be executed. + :param parameters: The parameters to render the SQL query with. + """ + ... + + @overload def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql: str = ""): - """Get a set of records from Trino""" + ... + + def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql=None): + """:sphinx-autoapi-skip:""" if hql: warnings.warn( "The hql parameter has been deprecated. You should pass the sql parameter.", @@ -122,8 +134,20 @@ def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql: str except DatabaseError as e: raise TrinoException(e) + @overload + def get_first(self, sql: str = "", parameters: Optional[dict] = None) -> Any: + """Returns only the first row, regardless of how many rows the query returns. + :param sql: SQL statement to be executed. + :param parameters: The parameters to render the SQL query with. + """ + ... + + @overload def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql: str = "") -> Any: - """Returns only the first row, regardless of how many rows the query returns.""" + ... + + def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql=None) -> Any: + """:sphinx-autoapi-skip:""" if hql: warnings.warn( "The hql parameter has been deprecated. You should pass the sql parameter.", @@ -137,10 +161,26 @@ def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql: str = except DatabaseError as e: raise TrinoException(e) + @overload + def get_pandas_df( + self, sql: str = "", parameters: Optional[dict] = None, **kwargs + ): # type: ignore[override] + """Get a pandas dataframe from a sql query. + :param sql: SQL statement to be executed. + :param parameters: The parameters to render the SQL query with. + """ + ... + + @overload def get_pandas_df( self, sql: str = "", parameters: Optional[dict] = None, hql: str = "", **kwargs ): # type: ignore[override] - """Get a pandas dataframe from a sql query.""" + ... + + def get_pandas_df( + self, sql: str = "", parameters: Optional[dict] = None, hql=None, **kwargs + ): # type: ignore[override] + """:sphinx-autoapi-skip:""" if hql: warnings.warn( "The hql parameter has been deprecated. You should pass the sql parameter.", @@ -165,15 +205,37 @@ def get_pandas_df( df = pandas.DataFrame(**kwargs) return df + @overload def run( self, sql: str = "", autocommit: bool = False, parameters: Optional[dict] = None, handler: Optional[Callable] = None, - hql: str = "", ) -> None: """Execute the statement against Trino. Can be used to create views.""" + ... + + @overload + def run( + self, + sql: str = "", + autocommit: bool = False, + parameters: Optional[dict] = None, + handler: Optional[Callable] = None, + hql: str = "", + ) -> None: + ... + + def run( + self, + sql: str = "", + autocommit: bool = False, + parameters: Optional[dict] = None, + handler: Optional[Callable] = None, + hql=None, + ) -> None: + """:sphinx-autoapi-skip:""" if hql: warnings.warn( "The hql parameter has been deprecated. You should pass the sql parameter.", From 9db1cb13f14384a6dfd3288a7bdce61409e95dde Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Mon, 21 Feb 2022 13:20:14 +0530 Subject: [PATCH 10/12] Fixed docstings. --- airflow/providers/presto/hooks/presto.py | 7 +++---- airflow/providers/trino/hooks/trino.py | 7 +++---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/airflow/providers/presto/hooks/presto.py b/airflow/providers/presto/hooks/presto.py index 1f6d3d3f0674d..8f7c283fa735c 100644 --- a/airflow/providers/presto/hooks/presto.py +++ b/airflow/providers/presto/hooks/presto.py @@ -115,10 +115,10 @@ def _strip_sql(sql: str) -> str: @overload def get_records(self, sql: str = "", parameters: Optional[dict] = None): """Get a set of records from Presto + :param sql: SQL statement to be executed. :param parameters: The parameters to render the SQL query with. """ - ... @overload def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql: str = ""): @@ -142,10 +142,10 @@ def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql=None @overload def get_first(self, sql: str = "", parameters: Optional[dict] = None) -> Any: """Returns only the first row, regardless of how many rows the query returns. + :param sql: SQL statement to be executed. :param parameters: The parameters to render the SQL query with. """ - ... @overload def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql: str = "") -> Any: @@ -169,10 +169,10 @@ def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql=None) @overload def get_pandas_df(self, sql: str = "", parameters=None, **kwargs): """Get a pandas dataframe from a sql query. + :param sql: SQL statement to be executed. :param parameters: The parameters to render the SQL query with. """ - ... @overload def get_pandas_df(self, sql: str = "", parameters=None, hql: str = "", **kwargs): @@ -213,7 +213,6 @@ def run( handler: Optional[Callable] = None, ) -> None: """Execute the statement against Presto. Can be used to create views.""" - ... @overload def run( diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py index c0423ca7054cc..5b8eac2661471 100644 --- a/airflow/providers/trino/hooks/trino.py +++ b/airflow/providers/trino/hooks/trino.py @@ -110,10 +110,10 @@ def _strip_sql(sql: str) -> str: @overload def get_records(self, sql: str = "", parameters: Optional[dict] = None): """Get a set of records from Trino + :param sql: SQL statement to be executed. :param parameters: The parameters to render the SQL query with. """ - ... @overload def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql: str = ""): @@ -137,10 +137,10 @@ def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql=None @overload def get_first(self, sql: str = "", parameters: Optional[dict] = None) -> Any: """Returns only the first row, regardless of how many rows the query returns. + :param sql: SQL statement to be executed. :param parameters: The parameters to render the SQL query with. """ - ... @overload def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql: str = "") -> Any: @@ -166,10 +166,10 @@ def get_pandas_df( self, sql: str = "", parameters: Optional[dict] = None, **kwargs ): # type: ignore[override] """Get a pandas dataframe from a sql query. + :param sql: SQL statement to be executed. :param parameters: The parameters to render the SQL query with. """ - ... @overload def get_pandas_df( @@ -214,7 +214,6 @@ def run( handler: Optional[Callable] = None, ) -> None: """Execute the statement against Trino. Can be used to create views.""" - ... @overload def run( From 532fce2cd6102cba2c19aa8bf06173729aa3896b Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Mon, 21 Feb 2022 14:04:04 +0530 Subject: [PATCH 11/12] Added 'sphinx-autoapi-skip' in docstrings of overloaded fucntions. --- airflow/providers/presto/hooks/presto.py | 8 ++++---- airflow/providers/trino/hooks/trino.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/airflow/providers/presto/hooks/presto.py b/airflow/providers/presto/hooks/presto.py index 8f7c283fa735c..5d26f98028741 100644 --- a/airflow/providers/presto/hooks/presto.py +++ b/airflow/providers/presto/hooks/presto.py @@ -122,7 +122,7 @@ def get_records(self, sql: str = "", parameters: Optional[dict] = None): @overload def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql: str = ""): - ... + """:sphinx-autoapi-skip:""" def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql=None): """:sphinx-autoapi-skip:""" @@ -149,7 +149,7 @@ def get_first(self, sql: str = "", parameters: Optional[dict] = None) -> Any: @overload def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql: str = "") -> Any: - ... + """:sphinx-autoapi-skip:""" def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql=None) -> Any: """:sphinx-autoapi-skip:""" @@ -176,7 +176,7 @@ def get_pandas_df(self, sql: str = "", parameters=None, **kwargs): @overload def get_pandas_df(self, sql: str = "", parameters=None, hql: str = "", **kwargs): - ... + """:sphinx-autoapi-skip:""" def get_pandas_df(self, sql: str = "", parameters=None, hql=None, **kwargs): """:sphinx-autoapi-skip:""" @@ -223,7 +223,7 @@ def run( handler: Optional[Callable] = None, hql: str = "", ) -> None: - ... + """:sphinx-autoapi-skip:""" def run( self, diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py index 5b8eac2661471..e1685447b518f 100644 --- a/airflow/providers/trino/hooks/trino.py +++ b/airflow/providers/trino/hooks/trino.py @@ -117,7 +117,7 @@ def get_records(self, sql: str = "", parameters: Optional[dict] = None): @overload def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql: str = ""): - ... + """:sphinx-autoapi-skip:""" def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql=None): """:sphinx-autoapi-skip:""" @@ -144,7 +144,7 @@ def get_first(self, sql: str = "", parameters: Optional[dict] = None) -> Any: @overload def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql: str = "") -> Any: - ... + """:sphinx-autoapi-skip:""" def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql=None) -> Any: """:sphinx-autoapi-skip:""" @@ -175,7 +175,7 @@ def get_pandas_df( def get_pandas_df( self, sql: str = "", parameters: Optional[dict] = None, hql: str = "", **kwargs ): # type: ignore[override] - ... + """:sphinx-autoapi-skip:""" def get_pandas_df( self, sql: str = "", parameters: Optional[dict] = None, hql=None, **kwargs @@ -224,7 +224,7 @@ def run( handler: Optional[Callable] = None, hql: str = "", ) -> None: - ... + """:sphinx-autoapi-skip:""" def run( self, From f7ea13a2b5b0737e42f7073abc85175298942538 Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Mon, 21 Feb 2022 15:40:54 +0530 Subject: [PATCH 12/12] Fixed signature of methods in presto hook and trino hook. --- airflow/providers/presto/hooks/presto.py | 8 ++++---- airflow/providers/trino/hooks/trino.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/airflow/providers/presto/hooks/presto.py b/airflow/providers/presto/hooks/presto.py index 5d26f98028741..419b571c9f92b 100644 --- a/airflow/providers/presto/hooks/presto.py +++ b/airflow/providers/presto/hooks/presto.py @@ -124,7 +124,7 @@ def get_records(self, sql: str = "", parameters: Optional[dict] = None): def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql: str = ""): """:sphinx-autoapi-skip:""" - def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql=None): + def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql: str = ""): """:sphinx-autoapi-skip:""" if hql: warnings.warn( @@ -151,7 +151,7 @@ def get_first(self, sql: str = "", parameters: Optional[dict] = None) -> Any: def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql: str = "") -> Any: """:sphinx-autoapi-skip:""" - def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql=None) -> Any: + def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql: str = "") -> Any: """:sphinx-autoapi-skip:""" if hql: warnings.warn( @@ -178,7 +178,7 @@ def get_pandas_df(self, sql: str = "", parameters=None, **kwargs): def get_pandas_df(self, sql: str = "", parameters=None, hql: str = "", **kwargs): """:sphinx-autoapi-skip:""" - def get_pandas_df(self, sql: str = "", parameters=None, hql=None, **kwargs): + def get_pandas_df(self, sql: str = "", parameters=None, hql: str = "", **kwargs): """:sphinx-autoapi-skip:""" if hql: warnings.warn( @@ -231,7 +231,7 @@ def run( autocommit: bool = False, parameters: Optional[dict] = None, handler: Optional[Callable] = None, - hql=None, + hql: str = "", ) -> None: """:sphinx-autoapi-skip:""" if hql: diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py index e1685447b518f..2401c327d79b2 100644 --- a/airflow/providers/trino/hooks/trino.py +++ b/airflow/providers/trino/hooks/trino.py @@ -119,7 +119,7 @@ def get_records(self, sql: str = "", parameters: Optional[dict] = None): def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql: str = ""): """:sphinx-autoapi-skip:""" - def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql=None): + def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql: str = ""): """:sphinx-autoapi-skip:""" if hql: warnings.warn( @@ -146,7 +146,7 @@ def get_first(self, sql: str = "", parameters: Optional[dict] = None) -> Any: def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql: str = "") -> Any: """:sphinx-autoapi-skip:""" - def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql=None) -> Any: + def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql: str = "") -> Any: """:sphinx-autoapi-skip:""" if hql: warnings.warn( @@ -178,7 +178,7 @@ def get_pandas_df( """:sphinx-autoapi-skip:""" def get_pandas_df( - self, sql: str = "", parameters: Optional[dict] = None, hql=None, **kwargs + self, sql: str = "", parameters: Optional[dict] = None, hql: str = "", **kwargs ): # type: ignore[override] """:sphinx-autoapi-skip:""" if hql: @@ -232,7 +232,7 @@ def run( autocommit: bool = False, parameters: Optional[dict] = None, handler: Optional[Callable] = None, - hql=None, + hql: str = "", ) -> None: """:sphinx-autoapi-skip:""" if hql: