Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 26 additions & 11 deletions providers/postgres/src/airflow/providers/postgres/hooks/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from more_itertools import chunked
from psycopg2 import connect as ppg2_connect
from psycopg2.extras import DictCursor, NamedTupleCursor, RealDictCursor, execute_batch
from psycopg2.extras import DictCursor, NamedTupleCursor, RealDictCursor, execute_values

from airflow.providers.common.compat.sdk import (
AirflowException,
Expand Down Expand Up @@ -660,6 +660,10 @@ def insert_rows(
"""
Insert a collection of tuples into a table.

When ``fast_executemany=True`` with psycopg2, uses ``execute_values`` which batches
all rows into a single INSERT statement for better performance.
For psycopg3, the default ``executemany`` already uses pipelining for high performance.

Rows are inserted in chunks, each chunk (of size ``commit_every``) is
done in a new transaction.

Expand All @@ -668,20 +672,29 @@ def insert_rows(
:param target_fields: The names of the columns to fill in the table
:param commit_every: The maximum number of rows to insert in one
transaction. Set to 0 to insert all rows in one transaction.
:param replace: Whether to replace instead of insert
:param replace: Whether to replace instead of insert (uses ON CONFLICT)
:param executemany: If True, all rows are inserted at once in
chunks defined by the commit_every parameter. This only works if all rows
have same number of column names, but leads to better performance.
:param fast_executemany: If True, rows will be inserted using an optimized
bulk execution strategy (``psycopg2.extras.execute_batch``). This can
significantly improve performance for large inserts. If set to False,
the method falls back to the default implementation from
``DbApiHook.insert_rows``.
bulk execution strategy (``psycopg2.extras.execute_values``), unless psycopg3
is being used. This can significantly improve performance for large inserts.
If set to False or psycopg3 is being used, the method falls back to the default
implementation from ``DbApiHook.insert_rows``.
:param autocommit: What to set the connection's autocommit setting to
before executing the query.
"""
# if fast_executemany is disabled, defer to default implementation of insert_rows in DbApiHook
if not fast_executemany:
# psycopg3's executemany already uses pipelining, so use default implementation
# Only override for psycopg2 with fast_executemany to use execute_values
if USE_PSYCOPG3 and fast_executemany:
self.log.warning(
"fast_executemany=True has no effect when using psycopg3. "
"psycopg3's executemany already uses pipelining for optimal performance."
)
if USE_PSYCOPG3 or not fast_executemany:
# Reset to default format in case a previous fast_executemany call failed
self._insert_statement_format = "INSERT INTO {} {} VALUES ({})"

return super().insert_rows(
table,
rows,
Expand All @@ -693,9 +706,11 @@ def insert_rows(
**kwargs,
)

# if fast_executemany is enabled, use optimized execute_batch from psycopg
# if fast_executemany is enabled with psycopg2, use optimized execute_values from psycopg
self._insert_statement_format = "INSERT INTO {} {} VALUES %s"

nb_rows = 0
sql = None # not generated unless we actually process at least one chunk
sql: str | None = None # not generated unless we actually process at least one chunk
with self._create_autocommit_connection(autocommit) as conn:
conn.commit()
with closing(conn.cursor()) as cur:
Expand All @@ -710,7 +725,7 @@ def insert_rows(
self.log.debug("Generated sql: %s", sql)

try:
execute_batch(cur, sql, values, page_size=commit_every)
execute_values(cur, sql, values, page_size=commit_every)
except Exception as e:
self.log.error("Generated sql: %s", sql)
self.log.error("Parameters: %s", values)
Expand Down
48 changes: 40 additions & 8 deletions providers/postgres/tests/unit/postgres/hooks/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -1012,8 +1012,8 @@ def test_insert_rows_hook_lineage(self, mock_send_lineage):
assert call_kw["sql"] == f"INSERT INTO {table} VALUES (%s)"
assert call_kw["row_count"] == 2

@mock.patch("airflow.providers.postgres.hooks.postgres.execute_batch")
def test_insert_rows_fast_executemany(self, mock_execute_batch):
@mock.patch("airflow.providers.postgres.hooks.postgres.execute_values")
def test_insert_rows_fast_executemany(self, mock_execute_values):
table = "table"
rows = [("hello",), ("world",)]

Expand All @@ -1025,9 +1025,9 @@ def test_insert_rows_fast_executemany(self, mock_execute_batch):
commit_count = 2 # The first and last commit
assert self.conn.commit.call_count == commit_count

mock_execute_batch.assert_called_once_with(
mock_execute_values.assert_called_once_with(
self.cur,
f"INSERT INTO {table} VALUES (%s)", # expected SQL
f"INSERT INTO {table} VALUES %s", # expected SQL
[("hello",), ("world",)], # expected values
page_size=1000,
)
Expand All @@ -1036,9 +1036,8 @@ def test_insert_rows_fast_executemany(self, mock_execute_batch):
self.cur.executemany.assert_not_called()

@mock.patch("airflow.providers.postgres.hooks.postgres.send_sql_hook_lineage")
@mock.patch("airflow.providers.postgres.hooks.postgres.execute_batch")
def test_insert_rows_fast_executemany_hook_lineage(self, mock_execute_batch, mock_send_lineage):

@mock.patch("airflow.providers.postgres.hooks.postgres.execute_values")
def test_insert_rows_fast_executemany_hook_lineage(self, mock_execute_values, mock_send_lineage):
table = "table"
rows = [("hello",), ("world",)]

Expand All @@ -1047,9 +1046,28 @@ def test_insert_rows_fast_executemany_hook_lineage(self, mock_execute_batch, moc
mock_send_lineage.assert_called_once()
call_kw = mock_send_lineage.call_args.kwargs
assert call_kw["context"] is self.db_hook
assert call_kw["sql"] == f"INSERT INTO {table} VALUES (%s)"
assert call_kw["sql"] == f"INSERT INTO {table} VALUES %s"
assert call_kw["row_count"] == 2

@mock.patch("airflow.providers.postgres.hooks.postgres.USE_PSYCOPG3", True)
@mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.insert_rows")
def test_insert_rows_fast_executemany_psycopg3_fallback(self, mock_super_insert_rows):
"""Verify psycopg3 falls back to default implementation even with fast_executemany=True."""
table = "table"
rows = [("hello",), ("world",)]

self.db_hook.insert_rows(table, rows, fast_executemany=True)

mock_super_insert_rows.assert_called_once_with(
table,
rows,
target_fields=None,
commit_every=1000,
replace=False,
executemany=False,
autocommit=False,
)

@pytest.mark.usefixtures("reset_logging_config")
def test_get_all_db_log_messages(self, mocker):
messages = ["a", "b", "c"]
Expand Down Expand Up @@ -1207,3 +1225,17 @@ def test_log_db_messages_by_db_proc(self, mocker):
mock_logger.info.assert_any_call("Message from db: 42")
finally:
hook.run(sql=f"DROP PROCEDURE {proc_name} (s text)")

@pytest.mark.usefixtures("reset_logging_config")
def test_insert_rows_fast_executemany_psycopg3_logs_warning(self, mocker):
mock_logger = mocker.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.log")

table = "table"
rows = [("hello",), ("world",)]

self.db_hook.insert_rows(table, rows, fast_executemany=True)

mock_logger.warning.assert_called_once_with(
"fast_executemany=True has no effect when using psycopg3. "
"psycopg3's executemany already uses pipelining for optimal performance."
)
Loading