Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,11 @@ def generate_replace_sql(self, table, values, target_fields, **kwargs) -> str:
:param table: Name of the target table
:param values: The row to insert into the table
:param target_fields: The names of the columns to fill in the table
:param replace: Whether to replace instead of insert
:param replace_index: the column or list of column names to act as
index for the ON CONFLICT clause
:param replace_target: Column name or list of column names to update when
a conflict occurs. If omitted, all non-conflict columns are updated.
If an empty list is provided, ``DO NOTHING`` is used.
:return: The generated INSERT or REPLACE SQL statement
"""
if not target_fields:
Expand All @@ -124,7 +126,13 @@ def generate_replace_sql(self, table, values, target_fields, **kwargs) -> str:

sql = self.generate_insert_sql(table, values, target_fields, **kwargs)
on_conflict_str = f" ON CONFLICT ({', '.join(map(self.escape_word, replace_index))})"
replace_target = [self.escape_word(f) for f in target_fields if f not in replace_index]

replace_target = kwargs.get("replace_target")

if replace_target is None:
replace_target = [self.escape_word(f) for f in target_fields if f not in replace_index]
else:
replace_target = [self.escape_word(f) for f in replace_target]
Comment thread
SameerMesiah97 marked this conversation as resolved.

if replace_target:
replace_target_str = ", ".join(f"{col} = excluded.{col}" for col in replace_target)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import annotations

import os
from collections.abc import Mapping
from collections.abc import Iterable, Mapping
from contextlib import closing
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeAlias, cast, overload
Expand Down Expand Up @@ -726,3 +726,42 @@ def insert_rows(

self.log.info("Done loading. Loaded a total of %s rows into %s", nb_rows, table)
return None

def upsert_rows(
self,
table: str,
rows: Iterable[tuple[Any, ...]],
target_fields: list[str],
conflict_fields: list[str],
update_fields: list[str] | None = None,
commit_every: int = 1000,
*,
fast_executemany: bool = False,
autocommit: bool = False,
) -> None:
"""
Upsert rows into a PostgreSQL table using ``ON CONFLICT``.

:param table: Name of the target table.
:param rows: Rows to upsert.
:param target_fields: Non-empty column names used in the ``INSERT`` statement.
:param conflict_fields: Non-empty column names used in the ``ON CONFLICT`` clause.
:param update_fields: Columns updated on conflict. If omitted, all
non-conflict columns are updated. If an empty list is provided,
conflicting rows are ignored via ``DO NOTHING``.
:param commit_every: Maximum number of rows per transaction. Default value is 1000.
:param fast_executemany: Use ``psycopg2.extras.execute_batch`` for improved
batch performance.
:param autocommit: Connection autocommit setting.
"""
return self.insert_rows(
table=table,
rows=rows,
target_fields=target_fields,
replace_index=conflict_fields,
replace_target=update_fields,
commit_every=commit_every,
replace=True,
fast_executemany=fast_executemany,
autocommit=autocommit,
)
57 changes: 57 additions & 0 deletions providers/postgres/tests/unit/postgres/dialects/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

from unittest.mock import MagicMock

import pytest

from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.postgres.dialects.postgres import PostgresDialect

Expand Down Expand Up @@ -103,3 +105,58 @@ def test_generate_replace_sql_when_escape_column_names_is_enabled(self):
INSERT INTO hollywood.actors ("id", "name", "firstname", "age") VALUES (?,?,?,?,?) ON CONFLICT ("id") DO UPDATE SET "name" = excluded."name", "firstname" = excluded."firstname", "age" = excluded."age"
""".strip()
)

@pytest.mark.parametrize(
("replace_index", "replace_target", "expected_clause"),
[
(
None,
["name"],
"ON CONFLICT (id) DO UPDATE SET name = excluded.name",
),
(
None,
["name", "age"],
"ON CONFLICT (id) DO UPDATE SET name = excluded.name, age = excluded.age",
),
(
None,
[],
"ON CONFLICT (id) DO NOTHING",
),
(
["id", "name"],
["age"],
"ON CONFLICT (id, name) DO UPDATE SET age = excluded.age",
),
],
)
def test_generate_replace_sql_with_replace_target(
self,
replace_index,
replace_target,
expected_clause,
):
values = [
{"id": 1, "name": "Stallone", "firstname": "Sylvester", "age": "78"},
{"id": 2, "name": "Statham", "firstname": "Jason", "age": "57"},
{"id": 3, "name": "Li", "firstname": "Jet", "age": "61"},
{"id": 4, "name": "Lundgren", "firstname": "Dolph", "age": "66"},
{"id": 5, "name": "Norris", "firstname": "Chuck", "age": "84"},
]

target_fields = ["id", "name", "firstname", "age"]

sql = PostgresDialect(self.test_db_hook).generate_replace_sql(
"hollywood.actors",
values,
target_fields,
replace_index=replace_index,
replace_target=replace_target,
)

assert (
sql
== f"""
INSERT INTO hollywood.actors (id, name, firstname, age) VALUES (?,?,?,?,?) {expected_clause}""".strip()
)
29 changes: 29 additions & 0 deletions providers/postgres/tests/unit/postgres/hooks/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,6 +911,35 @@ def test_insert_rows_replace_all_index(self):
)
self.cur.executemany.assert_any_call(sql, rows)

@mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.insert_rows")
def test_upsert_rows(self, mock_insert_rows):

rows = [(1, "hello")]
table = "table"

self.db_hook.upsert_rows(
table=table,
rows=rows,
target_fields=["id", "value"],
conflict_fields=["id"],
update_fields=["value"],
commit_every=123,
fast_executemany=True,
autocommit=True,
)

mock_insert_rows.assert_called_once_with(
table=table,
rows=rows,
target_fields=["id", "value"],
replace_index=["id"],
replace_target=["value"],
commit_every=123,
replace=True,
fast_executemany=True,
autocommit=True,
)

def test_dialect_name(self):
assert self.db_hook.dialect_name == "postgresql"

Expand Down
Loading