Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# 0.5.0
- paginated merge support via optional `merge_batch_size` parameter to bound memory per transaction on large datasets
- more efficient backend resolving
- logging namespace updated to reflect package name
- Python 3.13 support (updated oa-configurator dependency)

## 0.1.0
- initial commit
- stripped out generalisable functionality from omop-alchemy so that it could be reused in multiple clinical data models
Expand Down
10 changes: 6 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "orm-loader"
version = "0.4.1"
version = "0.5.0"
description = "Generic base classes to handle ORM functionality for multiple downstream datamodels"
readme = "README.md"
authors = [
Expand All @@ -9,6 +9,7 @@ authors = [
requires-python = ">=3.12"
dependencies = [
"chardet>=5.2.0",
"oa-configurator>=0.1.1",
"pandas>=2.3.3",
"pyarrow>=23.0.0",
"sqlalchemy>=2.0.45",
Expand Down Expand Up @@ -41,6 +42,7 @@ postgres = [
"psycopg[binary]>=3.2",
]
dev = [
"oa-configurator[postgres]>=0.1.1",
"pytest>=9.0.3",
"mypy>=1.19.1",
"ruff>=0.14.11",
Expand All @@ -53,6 +55,9 @@ dev = [
"python-dotenv"
]

[project.entry-points."omop.config"]
orm_loader = "orm_loader.config:OrmLoaderConfig"

[tool.setuptools]
packages = ["orm_loader"]

Expand All @@ -70,9 +75,6 @@ python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
addopts = "-ra"
markers = [
"postgres: requires a running Postgres instance (set TEST_POSTGRES_URL)",
]

[tool.pyright]
reportMissingTypeStubs = false
6 changes: 6 additions & 0 deletions src/orm_loader/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ def merge_replace(
target_name: str,
staging_name: str,
pk_cols: list[str],
*,
merge_batch_size: int | None = None,
) -> None:
"""Merge staging rows by replacing matching target rows first."""

Expand All @@ -192,6 +194,8 @@ def merge_upsert(
target_name: str,
staging_name: str,
pk_cols: list[str],
*,
merge_batch_size: int | None = None,
) -> None:
"""Merge staging rows using backend-specific upsert semantics."""

Expand All @@ -202,6 +206,8 @@ def merge_insert(
session: so.Session,
target_name: str,
staging_name: str,
*,
merge_batch_size: int | None = None,
) -> None:
"""Insert all staging rows into the target table."""

Expand Down
152 changes: 115 additions & 37 deletions src/orm_loader/backends/postgres.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

from contextlib import contextmanager, AbstractContextManager
from contextlib import AbstractContextManager, contextmanager
from typing import TYPE_CHECKING, Any

import sqlalchemy as sa
import sqlalchemy.orm as so
import sqlalchemy.event as sae
import sqlalchemy.orm as so

from .base import BackendCapabilities, DatabaseBackend, Dialect
from ..loaders.loading_helpers import quick_load_pg
from .base import BackendCapabilities, DatabaseBackend, Dialect

if TYPE_CHECKING:
from sqlalchemy.engine import Connection, Engine
Expand Down Expand Up @@ -57,6 +58,14 @@ def create_staging_table(
for col in computed_cols:
session.execute(sa.text(f'ALTER TABLE "{staging_name}" DROP COLUMN "{col}";'))

# allows pagniation in O(N log N) time for large tables in merge_insert without needing to add an index on every staging table
session.execute(
sa.text(
f'ALTER TABLE "{staging_name}" ADD COLUMN _rownum BIGINT'
f" GENERATED ALWAYS AS IDENTITY (CACHE 1000);"
)
)

session.commit()

def drop_staging_table(
Expand Down Expand Up @@ -124,19 +133,39 @@ def merge_replace(
target_name: str,
staging_name: str,
pk_cols: list[str],
*,
merge_batch_size: int | None = None,
) -> None:
pk_join = " AND ".join(
f't."{c}" = s."{c}"' for c in pk_cols
pk_join = " AND ".join(f't."{c}" = s."{c}"' for c in pk_cols)

non_paginated_replace = sa.text(
f'DELETE FROM "{target_name}" t USING "{staging_name}" s WHERE {pk_join}'
)
session.execute(
sa.text(
f"""
DELETE FROM "{target_name}" t
USING "{staging_name}" s
WHERE {pk_join};
"""

if merge_batch_size is None:
session.execute(non_paginated_replace)
return

total = session.execute(sa.text(f'SELECT COUNT(*) FROM "{staging_name}"')).scalar_one()
if total <= merge_batch_size:
session.execute(non_paginated_replace)
return

session.execute(sa.text(f'CREATE INDEX IF NOT EXISTS "{staging_name}_rownum_idx" ON "{staging_name}" (_rownum)'))
session.commit()

start = 0
while start < total:
end = start + merge_batch_size
session.execute(
sa.text(
f'DELETE FROM "{target_name}" t USING "{staging_name}" s'
f' WHERE {pk_join} AND s._rownum > :start AND s._rownum <= :end'
),
{"start": start, "end": end},
)
)
session.commit()
start = end

def merge_upsert(
self,
Expand All @@ -145,47 +174,100 @@ def merge_upsert(
target_name: str,
staging_name: str,
pk_cols: list[str],
*,
merge_batch_size: int | None = None,
) -> None:
insertable_cols = self._insertable_column_names(table_cls)
cols_str = ", ".join(f'"{c}"' for c in insertable_cols)
conflict_cols = ", ".join(f'"{c}"' for c in pk_cols)
session.execute(
sa.text(
f"""
INSERT INTO "{target_name}" ({cols_str})
SELECT {cols_str} FROM "{staging_name}"
ON CONFLICT ({conflict_cols}) DO NOTHING;
"""
)

non_paginated_upsert = sa.text(
f'INSERT INTO "{target_name}" ({cols_str})'
f' SELECT {cols_str} FROM "{staging_name}"'
f' ON CONFLICT ({conflict_cols}) DO NOTHING'
)

if merge_batch_size is None:
session.execute(non_paginated_upsert)
return

total = session.execute(sa.text(f'SELECT COUNT(*) FROM "{staging_name}"')).scalar_one()
if total <= merge_batch_size:
session.execute(non_paginated_upsert)
return

session.execute(sa.text(f'CREATE INDEX IF NOT EXISTS "{staging_name}_rownum_idx" ON "{staging_name}" (_rownum)'))
session.commit()

start = 0
while start < total:
end = start + merge_batch_size
session.execute(
sa.text(
f'INSERT INTO "{target_name}" ({cols_str})'
f' SELECT {cols_str} FROM "{staging_name}"'
f' WHERE _rownum > :start AND _rownum <= :end'
f' ON CONFLICT ({conflict_cols}) DO NOTHING'
),
{"start": start, "end": end},
)
session.commit()
start = end

def merge_insert(
self,
table_cls: type["CSVTableProtocol"],
session: so.Session,
target_name: str,
staging_name: str,
*,
merge_batch_size: int | None = None,
) -> None:
insertable_cols = self._insertable_column_names(table_cls)
cols_str = ", ".join(f'"{c}"' for c in insertable_cols)
session.execute(
sa.text(
f"""
INSERT INTO "{target_name}" ({cols_str})
SELECT {cols_str} FROM "{staging_name}";
"""
)

non_paginated_insert = sa.text(
f'INSERT INTO "{target_name}" ({cols_str})'
f' SELECT {cols_str} FROM "{staging_name}"'
)

if merge_batch_size is None:
session.execute(non_paginated_insert)
return

total = session.execute(sa.text(f'SELECT COUNT(*) FROM "{staging_name}"')).scalar_one()
if total <= merge_batch_size:
session.execute(non_paginated_insert)
return

# Paginated path: index _rownum for O(N log N) range scans then
# INSERT in batch-sized transactions to bound WAL per commit.
# session_replication_role='replica' is session-level and persists
# across commits, so FK checks stay disabled for all batches.
session.execute(sa.text(f'CREATE INDEX IF NOT EXISTS "{staging_name}_rownum_idx" ON "{staging_name}" (_rownum)'))
session.commit()

start = 0
while start < total:
end = start + merge_batch_size
session.execute(
sa.text(
f'INSERT INTO "{target_name}" ({cols_str})'
f' SELECT {cols_str} FROM "{staging_name}"'
f' WHERE _rownum > :start AND _rownum <= :end'
),
{"start": start, "end": end},
)
session.commit()
start = end

def merge_context(
self,
table_cls: type["CSVTableProtocol"],
session: so.Session,
) -> AbstractContextManager[None]:
return self.bulk_load_context(session, disable_fk=True, no_autoflush=False)



def create_materialized_view(
self,
bind: Engine | Connection,
Expand All @@ -196,7 +278,7 @@ def create_materialized_view(

with self._as_connection(bind) as conn:
conn.execute(CreateMaterializedView(name, selectable))

def refresh_materialized_view(
self,
bind: Engine | Connection,
Expand All @@ -207,9 +289,7 @@ def refresh_materialized_view(
dialect = getattr(conn, "dialect", None)
if dialect is not None:
safe_name = dialect.identifier_preparer.quote(name)
conn.execute(
sa.text(f"REFRESH MATERIALIZED VIEW {safe_name};")
)
conn.execute(sa.text(f"REFRESH MATERIALIZED VIEW {safe_name};"))

@contextmanager
def engine_with_replica_role(self, engine: "Engine"):
Expand All @@ -230,8 +310,6 @@ def _set_replica_role(
with engine.connect() as conn:
conn = conn.execution_options(isolation_level="AUTOCOMMIT")
conn.execute(sa.text("SET session_replication_role = DEFAULT"))
role = conn.execute(
sa.text("SHOW session_replication_role")
).scalar()
role = conn.execute(sa.text("SHOW session_replication_role")).scalar()
if role != "origin":
raise RuntimeError("Failed to restore session_replication_role")
23 changes: 10 additions & 13 deletions src/orm_loader/backends/resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from sqlalchemy.engine import Connection, Engine


_BACKEND_TYPES: tuple[type[DatabaseBackend], ...] = (
PostgresBackend,
SQLiteBackend,
)
_BACKEND_TYPES: dict[Dialect, type[DatabaseBackend]] = {
Dialect.POSTGRESQL: PostgresBackend,
Dialect.SQLITE: SQLiteBackend,
}


def _dialect(bindable: "so.Session | Engine | Connection") -> Dialect:
Expand All @@ -34,13 +34,10 @@ def _dialect(bindable: "so.Session | Engine | Connection") -> Dialect:
) from exc


def resolve_backend(bindable: "so.Session | Engine | Connection") -> DatabaseBackend:
"""
Resolve a concrete backend from a SQLAlchemy session, engine, or connection.
"""
def resolve_backend(bindable: "so.Session | Engine | Connection", **kwargs) -> DatabaseBackend:
"""Resolve a concrete backend from a SQLAlchemy session, engine, or connection."""
dialect = _dialect(bindable)
for backend_type in _BACKEND_TYPES:
backend = backend_type()
if backend.supports_dialect(dialect):
return backend
raise NotImplementedError(f"No backend registered for dialect '{dialect.value}'")
try:
return _BACKEND_TYPES[dialect](**kwargs)
except KeyError:
raise NotImplementedError(f"No backend registered for dialect '{dialect.value}'")
6 changes: 6 additions & 0 deletions src/orm_loader/backends/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ def merge_replace(
target_name: str,
staging_name: str,
pk_cols: list[str],
*,
merge_batch_size: int | None = None,
) -> None:
if len(pk_cols) == 1:
pk = pk_cols[0]
Expand Down Expand Up @@ -176,6 +178,8 @@ def merge_upsert(
target_name: str,
staging_name: str,
pk_cols: list[str],
*,
merge_batch_size: int | None = None,
) -> None:
insertable_cols = self._insertable_column_names(table_cls)
cols_str = ", ".join(f'"{c}"' for c in insertable_cols)
Expand All @@ -194,6 +198,8 @@ def merge_insert(
session: so.Session,
target_name: str,
staging_name: str,
*,
merge_batch_size: int | None = None,
) -> None:
insertable_cols = self._insertable_column_names(table_cls)
cols_str = ", ".join(f'"{c}"' for c in insertable_cols)
Expand Down
Loading
Loading