From 971436962c13b1c0b4400378149102642f3b04fc Mon Sep 17 00:00:00 2001 From: Sameer Mesiah Date: Sat, 16 May 2026 17:05:27 +0100 Subject: [PATCH] Add configurable UPSERT update fields to PostgresHook Extend PostgreSQL ON CONFLICT support by allowing callers to specify which columns are updated when conflicts occur. Preserve existing behavior when no update fields are provided, support DO NOTHING semantics via an empty update field list, and add an upsert_rows convenience wrapper built on top of the existing insert_rows(replace=True) implementation. --- .../providers/postgres/dialects/postgres.py | 12 +++- .../providers/postgres/hooks/postgres.py | 41 ++++++++++++- .../unit/postgres/dialects/test_postgres.py | 57 +++++++++++++++++++ .../unit/postgres/hooks/test_postgres.py | 29 ++++++++++ 4 files changed, 136 insertions(+), 3 deletions(-) diff --git a/providers/postgres/src/airflow/providers/postgres/dialects/postgres.py b/providers/postgres/src/airflow/providers/postgres/dialects/postgres.py index 446ea199fa154..3635cca98c396 100644 --- a/providers/postgres/src/airflow/providers/postgres/dialects/postgres.py +++ b/providers/postgres/src/airflow/providers/postgres/dialects/postgres.py @@ -106,9 +106,11 @@ def generate_replace_sql(self, table, values, target_fields, **kwargs) -> str: :param table: Name of the target table :param values: The row to insert into the table :param target_fields: The names of the columns to fill in the table - :param replace: Whether to replace instead of insert :param replace_index: the column or list of column names to act as index for the ON CONFLICT clause + :param replace_target: Column name or list of column names to update when + a conflict occurs. If omitted, all non-conflict columns are updated. + If an empty list is provided, ``DO NOTHING`` is used. :return: The generated INSERT or REPLACE SQL statement """ if not target_fields: @@ -124,7 +126,13 @@ def generate_replace_sql(self, table, values, target_fields, **kwargs) -> str: sql = self.generate_insert_sql(table, values, target_fields, **kwargs) on_conflict_str = f" ON CONFLICT ({', '.join(map(self.escape_word, replace_index))})" - replace_target = [self.escape_word(f) for f in target_fields if f not in replace_index] + + replace_target = kwargs.get("replace_target") + + if replace_target is None: + replace_target = [self.escape_word(f) for f in target_fields if f not in replace_index] + else: + replace_target = [self.escape_word(f) for f in replace_target] if replace_target: replace_target_str = ", ".join(f"{col} = excluded.{col}" for col in replace_target) diff --git a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py index ec28bcc834fed..7241eb1a4af59 100644 --- a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py +++ b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py @@ -18,7 +18,7 @@ from __future__ import annotations import os -from collections.abc import Mapping +from collections.abc import Iterable, Mapping from contextlib import closing from copy import deepcopy from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeAlias, cast, overload @@ -726,3 +726,42 @@ def insert_rows( self.log.info("Done loading. Loaded a total of %s rows into %s", nb_rows, table) return None + + def upsert_rows( + self, + table: str, + rows: Iterable[tuple[Any, ...]], + target_fields: list[str], + conflict_fields: list[str], + update_fields: list[str] | None = None, + commit_every: int = 1000, + *, + fast_executemany: bool = False, + autocommit: bool = False, + ) -> None: + """ + Upsert rows into a PostgreSQL table using ``ON CONFLICT``. + + :param table: Name of the target table. + :param rows: Rows to upsert. + :param target_fields: Non-empty column names used in the ``INSERT`` statement. + :param conflict_fields: Non-empty column names used in the ``ON CONFLICT`` clause. + :param update_fields: Columns updated on conflict. If omitted, all + non-conflict columns are updated. If an empty list is provided, + conflicting rows are ignored via ``DO NOTHING``. + :param commit_every: Maximum number of rows per transaction. Default value is 1000. + :param fast_executemany: Use ``psycopg2.extras.execute_batch`` for improved + batch performance. + :param autocommit: Connection autocommit setting. + """ + return self.insert_rows( + table=table, + rows=rows, + target_fields=target_fields, + replace_index=conflict_fields, + replace_target=update_fields, + commit_every=commit_every, + replace=True, + fast_executemany=fast_executemany, + autocommit=autocommit, + ) diff --git a/providers/postgres/tests/unit/postgres/dialects/test_postgres.py b/providers/postgres/tests/unit/postgres/dialects/test_postgres.py index 999386baac12e..946805ec40fbd 100644 --- a/providers/postgres/tests/unit/postgres/dialects/test_postgres.py +++ b/providers/postgres/tests/unit/postgres/dialects/test_postgres.py @@ -19,6 +19,8 @@ from unittest.mock import MagicMock +import pytest + from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.postgres.dialects.postgres import PostgresDialect @@ -103,3 +105,58 @@ def test_generate_replace_sql_when_escape_column_names_is_enabled(self): INSERT INTO hollywood.actors ("id", "name", "firstname", "age") VALUES (?,?,?,?,?) ON CONFLICT ("id") DO UPDATE SET "name" = excluded."name", "firstname" = excluded."firstname", "age" = excluded."age" """.strip() ) + + @pytest.mark.parametrize( + ("replace_index", "replace_target", "expected_clause"), + [ + ( + None, + ["name"], + "ON CONFLICT (id) DO UPDATE SET name = excluded.name", + ), + ( + None, + ["name", "age"], + "ON CONFLICT (id) DO UPDATE SET name = excluded.name, age = excluded.age", + ), + ( + None, + [], + "ON CONFLICT (id) DO NOTHING", + ), + ( + ["id", "name"], + ["age"], + "ON CONFLICT (id, name) DO UPDATE SET age = excluded.age", + ), + ], + ) + def test_generate_replace_sql_with_replace_target( + self, + replace_index, + replace_target, + expected_clause, + ): + values = [ + {"id": 1, "name": "Stallone", "firstname": "Sylvester", "age": "78"}, + {"id": 2, "name": "Statham", "firstname": "Jason", "age": "57"}, + {"id": 3, "name": "Li", "firstname": "Jet", "age": "61"}, + {"id": 4, "name": "Lundgren", "firstname": "Dolph", "age": "66"}, + {"id": 5, "name": "Norris", "firstname": "Chuck", "age": "84"}, + ] + + target_fields = ["id", "name", "firstname", "age"] + + sql = PostgresDialect(self.test_db_hook).generate_replace_sql( + "hollywood.actors", + values, + target_fields, + replace_index=replace_index, + replace_target=replace_target, + ) + + assert ( + sql + == f""" + INSERT INTO hollywood.actors (id, name, firstname, age) VALUES (?,?,?,?,?) {expected_clause}""".strip() + ) diff --git a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py index fab8bc6f5d7d5..7825d98813209 100644 --- a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py +++ b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py @@ -911,6 +911,35 @@ def test_insert_rows_replace_all_index(self): ) self.cur.executemany.assert_any_call(sql, rows) + @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.insert_rows") + def test_upsert_rows(self, mock_insert_rows): + + rows = [(1, "hello")] + table = "table" + + self.db_hook.upsert_rows( + table=table, + rows=rows, + target_fields=["id", "value"], + conflict_fields=["id"], + update_fields=["value"], + commit_every=123, + fast_executemany=True, + autocommit=True, + ) + + mock_insert_rows.assert_called_once_with( + table=table, + rows=rows, + target_fields=["id", "value"], + replace_index=["id"], + replace_target=["value"], + commit_every=123, + replace=True, + fast_executemany=True, + autocommit=True, + ) + def test_dialect_name(self): assert self.db_hook.dialect_name == "postgresql"