diff --git a/providers/postgres/src/airflow/providers/postgres/dialects/postgres.py b/providers/postgres/src/airflow/providers/postgres/dialects/postgres.py index 446ea199fa154..3635cca98c396 100644 --- a/providers/postgres/src/airflow/providers/postgres/dialects/postgres.py +++ b/providers/postgres/src/airflow/providers/postgres/dialects/postgres.py @@ -106,9 +106,11 @@ def generate_replace_sql(self, table, values, target_fields, **kwargs) -> str: :param table: Name of the target table :param values: The row to insert into the table :param target_fields: The names of the columns to fill in the table - :param replace: Whether to replace instead of insert :param replace_index: the column or list of column names to act as index for the ON CONFLICT clause + :param replace_target: Column name or list of column names to update when + a conflict occurs. If omitted, all non-conflict columns are updated. + If an empty list is provided, ``DO NOTHING`` is used. :return: The generated INSERT or REPLACE SQL statement """ if not target_fields: @@ -124,7 +126,13 @@ def generate_replace_sql(self, table, values, target_fields, **kwargs) -> str: sql = self.generate_insert_sql(table, values, target_fields, **kwargs) on_conflict_str = f" ON CONFLICT ({', '.join(map(self.escape_word, replace_index))})" - replace_target = [self.escape_word(f) for f in target_fields if f not in replace_index] + + replace_target = kwargs.get("replace_target") + + if replace_target is None: + replace_target = [self.escape_word(f) for f in target_fields if f not in replace_index] + else: + replace_target = [self.escape_word(f) for f in replace_target] if replace_target: replace_target_str = ", ".join(f"{col} = excluded.{col}" for col in replace_target) diff --git a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py index ec28bcc834fed..7241eb1a4af59 100644 --- a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py +++ b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py @@ -18,7 +18,7 @@ from __future__ import annotations import os -from collections.abc import Mapping +from collections.abc import Iterable, Mapping from contextlib import closing from copy import deepcopy from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeAlias, cast, overload @@ -726,3 +726,42 @@ def insert_rows( self.log.info("Done loading. Loaded a total of %s rows into %s", nb_rows, table) return None + + def upsert_rows( + self, + table: str, + rows: Iterable[tuple[Any, ...]], + target_fields: list[str], + conflict_fields: list[str], + update_fields: list[str] | None = None, + commit_every: int = 1000, + *, + fast_executemany: bool = False, + autocommit: bool = False, + ) -> None: + """ + Upsert rows into a PostgreSQL table using ``ON CONFLICT``. + + :param table: Name of the target table. + :param rows: Rows to upsert. + :param target_fields: Non-empty column names used in the ``INSERT`` statement. + :param conflict_fields: Non-empty column names used in the ``ON CONFLICT`` clause. + :param update_fields: Columns updated on conflict. If omitted, all + non-conflict columns are updated. If an empty list is provided, + conflicting rows are ignored via ``DO NOTHING``. + :param commit_every: Maximum number of rows per transaction. Default value is 1000. + :param fast_executemany: Use ``psycopg2.extras.execute_batch`` for improved + batch performance. + :param autocommit: Connection autocommit setting. + """ + return self.insert_rows( + table=table, + rows=rows, + target_fields=target_fields, + replace_index=conflict_fields, + replace_target=update_fields, + commit_every=commit_every, + replace=True, + fast_executemany=fast_executemany, + autocommit=autocommit, + ) diff --git a/providers/postgres/tests/unit/postgres/dialects/test_postgres.py b/providers/postgres/tests/unit/postgres/dialects/test_postgres.py index 999386baac12e..946805ec40fbd 100644 --- a/providers/postgres/tests/unit/postgres/dialects/test_postgres.py +++ b/providers/postgres/tests/unit/postgres/dialects/test_postgres.py @@ -19,6 +19,8 @@ from unittest.mock import MagicMock +import pytest + from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.postgres.dialects.postgres import PostgresDialect @@ -103,3 +105,58 @@ def test_generate_replace_sql_when_escape_column_names_is_enabled(self): INSERT INTO hollywood.actors ("id", "name", "firstname", "age") VALUES (?,?,?,?,?) ON CONFLICT ("id") DO UPDATE SET "name" = excluded."name", "firstname" = excluded."firstname", "age" = excluded."age" """.strip() ) + + @pytest.mark.parametrize( + ("replace_index", "replace_target", "expected_clause"), + [ + ( + None, + ["name"], + "ON CONFLICT (id) DO UPDATE SET name = excluded.name", + ), + ( + None, + ["name", "age"], + "ON CONFLICT (id) DO UPDATE SET name = excluded.name, age = excluded.age", + ), + ( + None, + [], + "ON CONFLICT (id) DO NOTHING", + ), + ( + ["id", "name"], + ["age"], + "ON CONFLICT (id, name) DO UPDATE SET age = excluded.age", + ), + ], + ) + def test_generate_replace_sql_with_replace_target( + self, + replace_index, + replace_target, + expected_clause, + ): + values = [ + {"id": 1, "name": "Stallone", "firstname": "Sylvester", "age": "78"}, + {"id": 2, "name": "Statham", "firstname": "Jason", "age": "57"}, + {"id": 3, "name": "Li", "firstname": "Jet", "age": "61"}, + {"id": 4, "name": "Lundgren", "firstname": "Dolph", "age": "66"}, + {"id": 5, "name": "Norris", "firstname": "Chuck", "age": "84"}, + ] + + target_fields = ["id", "name", "firstname", "age"] + + sql = PostgresDialect(self.test_db_hook).generate_replace_sql( + "hollywood.actors", + values, + target_fields, + replace_index=replace_index, + replace_target=replace_target, + ) + + assert ( + sql + == f""" + INSERT INTO hollywood.actors (id, name, firstname, age) VALUES (?,?,?,?,?) {expected_clause}""".strip() + ) diff --git a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py index fab8bc6f5d7d5..7825d98813209 100644 --- a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py +++ b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py @@ -911,6 +911,35 @@ def test_insert_rows_replace_all_index(self): ) self.cur.executemany.assert_any_call(sql, rows) + @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.insert_rows") + def test_upsert_rows(self, mock_insert_rows): + + rows = [(1, "hello")] + table = "table" + + self.db_hook.upsert_rows( + table=table, + rows=rows, + target_fields=["id", "value"], + conflict_fields=["id"], + update_fields=["value"], + commit_every=123, + fast_executemany=True, + autocommit=True, + ) + + mock_insert_rows.assert_called_once_with( + table=table, + rows=rows, + target_fields=["id", "value"], + replace_index=["id"], + replace_target=["value"], + commit_every=123, + replace=True, + fast_executemany=True, + autocommit=True, + ) + def test_dialect_name(self): assert self.db_hook.dialect_name == "postgresql"