From a45c2b9135e3854888b1bbed4c76641abba685e9 Mon Sep 17 00:00:00 2001 From: Jarek Potiuk Date: Wed, 30 Nov 2022 00:07:06 +0100 Subject: [PATCH] Fix wrapping of run() method result of exasol and snoflake DB hooks The change #27912 fixed and unified behaviour of DBApiHooks across the board, but it missed two places where sql was mis-used and overridden in exasol and snowflake hooks. The check for "sql" type did not use the original sql parameter value but the one that was overridden later in the run method implementation. The fix is the same as applied in Databricks Hook and DBAPI generic run methods - using consistent typing and separate variable to convert the sql string into sql list. Related: https://github.com/astronomer/astro-sdk/pull/1324 --- airflow/providers/exasol/hooks/exasol.py | 12 +++++++----- airflow/providers/snowflake/hooks/snowflake.py | 14 +++++++++----- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/airflow/providers/exasol/hooks/exasol.py b/airflow/providers/exasol/hooks/exasol.py index 3dc6b81973c78..b8dde88772dbd 100644 --- a/airflow/providers/exasol/hooks/exasol.py +++ b/airflow/providers/exasol/hooks/exasol.py @@ -159,19 +159,21 @@ def run( """ if isinstance(sql, str): if split_statements: - sql = self.split_sql_string(sql) + sql_list: Iterable[str] = self.split_sql_string(sql) else: - sql = [self.strip_sql_string(sql)] + sql_list = [self.strip_sql_string(sql)] + else: + sql_list = sql - if sql: - self.log.debug("Executing following statements against Exasol DB: %s", list(sql)) + if sql_list: + self.log.debug("Executing following statements against Exasol DB: %s", list(sql_list)) else: raise ValueError("List of SQL statements is empty") with closing(self.get_conn()) as conn: self.set_autocommit(conn, autocommit) results = [] - for sql_statement in sql: + for sql_statement in sql_list: with closing(conn.execute(sql_statement, parameters)) as cur: self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters) if handler is not None: diff --git a/airflow/providers/snowflake/hooks/snowflake.py b/airflow/providers/snowflake/hooks/snowflake.py index c08247be21716..79bb6893180a9 100644 --- a/airflow/providers/snowflake/hooks/snowflake.py +++ b/airflow/providers/snowflake/hooks/snowflake.py @@ -353,12 +353,16 @@ def run( if isinstance(sql, str): if split_statements: split_statements_tuple = util_text.split_statements(StringIO(sql)) - sql = [sql_string for sql_string, _ in split_statements_tuple if sql_string] + sql_list: Iterable[str] = [ + sql_string for sql_string, _ in split_statements_tuple if sql_string + ] else: - sql = [self.strip_sql_string(sql)] + sql_list = [self.strip_sql_string(sql)] + else: + sql_list = sql - if sql: - self.log.debug("Executing following statements against Snowflake DB: %s", list(sql)) + if sql_list: + self.log.debug("Executing following statements against Snowflake DB: %s", sql_list) else: raise ValueError("List of SQL statements is empty") @@ -368,7 +372,7 @@ def run( # SnowflakeCursor does not extend ContextManager, so we have to ignore mypy error here with closing(conn.cursor(DictCursor)) as cur: # type: ignore[type-var] results = [] - for sql_statement in sql: + for sql_statement in sql_list: self._run_command(cur, sql_statement, parameters) if handler is not None: