From 0901bbe56a3d0caf5ed979311598d83cb90c43b7 Mon Sep 17 00:00:00 2001 From: Dmytro Kazanzhy Date: Fri, 29 Jul 2022 23:07:30 +0300 Subject: [PATCH 1/2] Fixing JdbcOperator non-SELECT statement run --- airflow/providers/jdbc/operators/jdbc.py | 9 +++++++-- tests/providers/jdbc/operators/test_jdbc.py | 16 ++++++++++++++-- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/airflow/providers/jdbc/operators/jdbc.py b/airflow/providers/jdbc/operators/jdbc.py index 6b38366b41e81..9b9d580e8232e 100644 --- a/airflow/providers/jdbc/operators/jdbc.py +++ b/airflow/providers/jdbc/operators/jdbc.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union +from typing import TYPE_CHECKING, Callable, Iterable, Mapping, Optional, Sequence, Union from airflow.models import BaseOperator from airflow.providers.common.sql.hooks.sql import fetch_all_handler @@ -57,6 +57,7 @@ def __init__( jdbc_conn_id: str = 'jdbc_default', autocommit: bool = False, parameters: Optional[Union[Iterable, Mapping]] = None, + handler: Callable = fetch_all_handler, **kwargs, ) -> None: super().__init__(**kwargs) @@ -64,9 +65,13 @@ def __init__( self.sql = sql self.jdbc_conn_id = jdbc_conn_id self.autocommit = autocommit + self.handler = handler self.hook = None def execute(self, context: 'Context'): self.log.info('Executing: %s', self.sql) hook = JdbcHook(jdbc_conn_id=self.jdbc_conn_id) - return hook.run(self.sql, self.autocommit, parameters=self.parameters, handler=fetch_all_handler) + if self.do_xcom_push: + return hook.run(self.sql, self.autocommit, parameters=self.parameters, handler=self.handler) + else: + return hook.run(self.sql, self.autocommit, parameters=self.parameters) diff --git a/tests/providers/jdbc/operators/test_jdbc.py b/tests/providers/jdbc/operators/test_jdbc.py index 9168674c566a8..7b40e48340881 100644 --- a/tests/providers/jdbc/operators/test_jdbc.py +++ b/tests/providers/jdbc/operators/test_jdbc.py @@ -28,8 +28,8 @@ def setUp(self): self.kwargs = dict(sql='sql', task_id='test_jdbc_operator', dag=None) @patch('airflow.providers.jdbc.operators.jdbc.JdbcHook') - def test_execute(self, mock_jdbc_hook): - jdbc_operator = JdbcOperator(**self.kwargs) + def test_execute_do_push(self, mock_jdbc_hook): + jdbc_operator = JdbcOperator(**self.kwargs, do_xcom_push=True) jdbc_operator.execute(context={}) mock_jdbc_hook.assert_called_once_with(jdbc_conn_id=jdbc_operator.jdbc_conn_id) @@ -39,3 +39,15 @@ def test_execute(self, mock_jdbc_hook): parameters=jdbc_operator.parameters, handler=fetch_all_handler, ) + + @patch('airflow.providers.jdbc.operators.jdbc.JdbcHook') + def test_execute_dont_push(self, mock_jdbc_hook): + jdbc_operator = JdbcOperator(**self.kwargs, do_xcom_push=False) + jdbc_operator.execute(context={}) + + mock_jdbc_hook.assert_called_once_with(jdbc_conn_id=jdbc_operator.jdbc_conn_id) + mock_jdbc_hook.return_value.run.assert_called_once_with( + jdbc_operator.sql, + jdbc_operator.autocommit, + parameters=jdbc_operator.parameters, + ) From 5dba22c173bb55f867b9c819b8d381852c43f688 Mon Sep 17 00:00:00 2001 From: Dmytro Kazanzhy Date: Tue, 2 Aug 2022 19:29:25 +0300 Subject: [PATCH 2/2] Correct function type --- airflow/providers/jdbc/operators/jdbc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/providers/jdbc/operators/jdbc.py b/airflow/providers/jdbc/operators/jdbc.py index 9b9d580e8232e..f45d112c43007 100644 --- a/airflow/providers/jdbc/operators/jdbc.py +++ b/airflow/providers/jdbc/operators/jdbc.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Callable, Iterable, Mapping, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, Optional, Sequence, Union from airflow.models import BaseOperator from airflow.providers.common.sql.hooks.sql import fetch_all_handler @@ -57,7 +57,7 @@ def __init__( jdbc_conn_id: str = 'jdbc_default', autocommit: bool = False, parameters: Optional[Union[Iterable, Mapping]] = None, - handler: Callable = fetch_all_handler, + handler: Callable[[Any], Any] = fetch_all_handler, **kwargs, ) -> None: super().__init__(**kwargs)