From 01139faad6339a2dcf600d279f17789874a42b6d Mon Sep 17 00:00:00 2001 From: vincbeck Date: Fri, 12 Apr 2024 14:52:02 -0300 Subject: [PATCH 1/2] Fix `DbApiHook.insert_rows` when `rows` is a generator --- airflow/providers/common/sql/hooks/sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py index 4625c2e014f7a..462a647ca57bf 100644 --- a/airflow/providers/common/sql/hooks/sql.py +++ b/airflow/providers/common/sql/hooks/sql.py @@ -594,7 +594,7 @@ def insert_rows( conn.commit() self.log.info("Loaded %s rows into %s so far", i, table) conn.commit() - self.log.info("Done loading. Loaded a total of %s rows into %s", len(rows), table) + self.log.info("Done loading. Loaded a total of %s rows into %s", len(list(rows)), table) @classmethod def _serialize_cells(cls, row, conn=None): From 0af7f4f4fbf7c9d870efe5b7ae3be54c407b02fd Mon Sep 17 00:00:00 2001 From: vincbeck Date: Fri, 12 Apr 2024 17:22:21 -0300 Subject: [PATCH 2/2] Fix increment --- airflow/providers/common/sql/hooks/sql.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py index 462a647ca57bf..332bd996fd7fe 100644 --- a/airflow/providers/common/sql/hooks/sql.py +++ b/airflow/providers/common/sql/hooks/sql.py @@ -568,6 +568,7 @@ def insert_rows( stacklevel=2, ) + nb_rows = 0 with self._create_autocommit_connection() as conn: conn.commit() with closing(conn.cursor()) as cur: @@ -584,6 +585,7 @@ def insert_rows( cur.executemany(sql, values) conn.commit() self.log.info("Loaded %s rows into %s so far", len(chunked_rows), table) + nb_rows += len(chunked_rows) else: for i, row in enumerate(rows, 1): values = self._serialize_cells(row, conn) @@ -593,8 +595,9 @@ def insert_rows( if commit_every and i % commit_every == 0: conn.commit() self.log.info("Loaded %s rows into %s so far", i, table) + nb_rows += 1 conn.commit() - self.log.info("Done loading. Loaded a total of %s rows into %s", len(list(rows)), table) + self.log.info("Done loading. Loaded a total of %s rows into %s", nb_rows, table) @classmethod def _serialize_cells(cls, row, conn=None):