Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions providers/common/ai/docs/toolsets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,20 @@ Curated toolset wrapping
The ``DbApiHook`` is resolved lazily from ``db_conn_id`` on first tool call
via ``BaseHook.get_connection(conn_id).get_hook()``.

In read-only mode (``allow_writes=False``, the default) the ``query`` tool also
accepts read-only metadata statements -- ``DESCRIBE``/``DESC`` and ``SHOW`` --
in addition to SELECT-family queries. Agents commonly open with ``DESCRIBE`` to
learn a table's columns, so permitting it keeps runs deterministic instead of
hard-failing on schema discovery. The toolset passes the connection's dialect to
the validator, so ``SHOW`` is accepted when the validator's parser recognizes it
for that dialect (Snowflake, MySQL, etc.); for other dialects (e.g. PostgreSQL),
or when the dialect cannot be determined, it stays rejected. Data-modifying
statements remain blocked -- including ones hidden behind ``DESCRIBE``/``EXPLAIN``
(e.g. ``EXPLAIN DELETE ...``, ``DESCRIBE DROP TABLE ...``), which the validator
rejects by scanning the parsed statement for write operations. Like ``SELECT``,
metadata statements are not scoped by ``allowed_tables`` (see
:ref:`allowed-tables-limitation`) -- an agent can ``DESCRIBE`` a table outside the
list, so rely on database permissions to restrict access.

Multi-schema warehouses
^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down Expand Up @@ -180,7 +194,8 @@ Parameters
- ``schema``: Default schema/namespace for unqualified table listing and
introspection. Schema-qualified ``allowed_tables`` entries override it per table.
- ``allow_writes``: Allow data-modifying SQL (INSERT, UPDATE, DELETE, etc.).
Default ``False`` — only SELECT-family statements are permitted.
Default ``False`` -- only SELECT-family and read-only metadata
(``DESCRIBE``/``SHOW``) statements are permitted.
- ``max_rows``: Maximum rows returned from the ``query`` tool. Default ``50``.

``DataFusionToolset``
Expand Down Expand Up @@ -534,7 +549,9 @@ No single layer is sufficient — they work together.
- Does not restrict what arguments the agent passes to allowed methods.
* - **SQLToolset: read-only by default**
- ``allow_writes=False`` (default) validates every SQL query through
``validate_sql()`` and rejects INSERT, UPDATE, DELETE, DROP, etc.
``validate_sql()``: SELECT-family and read-only metadata
(``DESCRIBE``/``SHOW``) statements pass; INSERT, UPDATE, DELETE, DROP,
and writes hidden behind ``EXPLAIN`` are rejected.
- Does not prevent the agent from reading sensitive data that the
database user has SELECT access to.
* - **DataFusionToolset: read-only by default**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
try:
from airflow.providers.common.ai.utils.sql_validation import (
DEFAULT_ALLOWED_TYPES,
resolve_sqlglot_dialect,
validate_sql as _validate_sql,
)
from airflow.providers.common.sql.datafusion.engine import DataFusionEngine
Expand All @@ -44,12 +45,6 @@
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.sdk import Context

# SQLAlchemy dialect_name → sqlglot dialect mapping for names that differ.
_SQLALCHEMY_TO_SQLGLOT_DIALECT: dict[str, str] = {
"postgresql": "postgres",
"mssql": "tsql",
}


class LLMSQLQueryOperator(LLMOperator):
"""
Expand Down Expand Up @@ -257,6 +252,4 @@ def _resolved_dialect(self) -> str | None:
raw = self.dialect
if not raw and self.db_hook and hasattr(self.db_hook, "dialect_name"):
raw = self.db_hook.dialect_name
if raw:
return _SQLALCHEMY_TO_SQLGLOT_DIALECT.get(raw, raw)
return None
return resolve_sqlglot_dialect(raw)
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@
from typing import TYPE_CHECKING, Any

try:
from airflow.providers.common.ai.utils.sql_validation import validate_sql as _validate_sql
from airflow.providers.common.ai.utils.sql_validation import (
resolve_sqlglot_dialect,
validate_sql as _validate_sql,
)
from airflow.providers.common.sql.hooks.sql import DbApiHook
except ImportError as e:
from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException
Expand Down Expand Up @@ -297,11 +300,23 @@ def _get_schema(self, table_name: str) -> str:
columns = hook.get_table_schema(table, schema=schema)
return json.dumps(columns)

def _dialect_for_validation(self) -> str | None:
    """
    Return the sqlglot dialect to use when validating SQL for this toolset.

    Reads ``dialect_name`` from the resolved DB hook (``None`` when the hook
    has no such attribute) and normalizes it via ``resolve_sqlglot_dialect``
    so DESCRIBE/SHOW validate correctly.
    """
    db_hook = self._get_db_hook()
    sqlalchemy_name = getattr(db_hook, "dialect_name", None)
    return resolve_sqlglot_dialect(sqlalchemy_name)

def _query(self, sql: str) -> str:
hook = self._get_db_hook()
if not self._allow_writes:
_validate_sql(sql)
# allow_read_only_metadata lets agents inspect schemas with DESCRIBE/SHOW
# (a common first move) instead of hard-failing; the deep scan still
# rejects any data-modifying statement, including EXPLAIN <write>.
_validate_sql(
sql,
dialect=self._dialect_for_validation(),
allow_read_only_metadata=True,
)

hook = self._get_db_hook()
try:
rows = hook.get_records(sql)
except Exception as e:
Expand Down Expand Up @@ -347,8 +362,13 @@ def _is_retryable_query_error(hook: DbApiHook, error: Exception) -> bool:
return False

def _check_query(self, sql: str) -> str:
# Resolve the dialect best-effort: if the connection can't be reached we
# still syntax-check dialect-agnostically rather than reporting invalid.
dialect: str | None = None
with suppress(Exception):
dialect = self._dialect_for_validation()
try:
_validate_sql(sql)
_validate_sql(sql, dialect=dialect, allow_read_only_metadata=True)
return json.dumps({"valid": True})
except Exception as e:
return json.dumps({"valid": False, "error": str(e)})
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,36 @@

import sqlglot
from sqlglot import exp
from sqlglot.dialects import Dialects
from sqlglot.errors import ErrorLevel

# Dialect names sqlglot recognizes. Used to drop unknown dialect names so a bad
# value never breaks parsing (sqlglot raises on an unknown dialect).
_KNOWN_SQLGLOT_DIALECTS: frozenset[str] = frozenset(d.value for d in Dialects)

# SQLAlchemy ``dialect_name`` → sqlglot dialect mapping for names that differ.
_SQLALCHEMY_TO_SQLGLOT_DIALECT: dict[str, str] = {
"postgresql": "postgres",
"mssql": "tsql",
}


def resolve_sqlglot_dialect(dialect_name: str | None) -> str | None:
    """
    Normalize a SQLAlchemy dialect name to a sqlglot dialect.

    Matching ignores surrounding whitespace and letter case, so values such as
    ``"PostgreSQL"`` or ``" snowflake "`` resolve the same way as their
    canonical lowercase spellings instead of silently degrading to
    dialect-agnostic parsing.

    Returns ``None`` (dialect-agnostic parsing) for empty, whitespace-only,
    non-string, or unknown inputs, so a bad dialect value never breaks SQL
    validation.

    :param dialect_name: A SQLAlchemy ``dialect_name`` (e.g. ``"postgresql"``).
    :return: The matching sqlglot dialect (e.g. ``"postgres"``), or ``None``.
    """
    if not isinstance(dialect_name, str):
        return None

    # The alias table is keyed by lowercase names; sqlglot's own dialect names
    # are expected to be lowercase too, so compare in that form.
    normalized = dialect_name.strip().lower()
    if not normalized:
        return None

    candidate = _SQLALCHEMY_TO_SQLGLOT_DIALECT.get(normalized, normalized)
    # Unknown names are dropped rather than passed on, because sqlglot raises
    # on a dialect it does not recognize.
    return candidate if candidate in _KNOWN_SQLGLOT_DIALECTS else None


# Allowlist: only these top-level statement types pass validation by default.
# - Select: plain queries and CTE-wrapped queries (WITH ... AS ... SELECT is parsed
# as Select with a `with` clause property — still a Select node at the top level)
Expand All @@ -39,9 +67,21 @@
exp.Except,
)

# Read-only metadata statements that introspect the schema without touching data:
# - Describe: DESCRIBE / DESC <table> (and EXPLAIN on some dialects)
# - Show: SHOW TABLES / SHOW COLUMNS / SHOW DATABASES, etc.
# Opt-in via ``allow_read_only_metadata=True``. SHOW only parses to ``exp.Show``
# when a dialect that supports it is passed (e.g. snowflake, mysql); without a
# dialect sqlglot falls back to ``exp.Command``, which stays blocked.
READ_ONLY_METADATA_TYPES: tuple[type[exp.Expr], ...] = (
exp.Describe,
exp.Show,
)

# Denylist: expression types that mutate data or schema when found anywhere in the AST.
# This catches data-modifying CTEs (e.g. WITH del AS (DELETE …) SELECT …),
# SELECT INTO, and other constructs that bypass top-level type checks.
# SELECT INTO, DDL or DML wrapped behind DESCRIBE/EXPLAIN (e.g. DESCRIBE DROP TABLE …),
# and other constructs that bypass top-level type checks.
# Note: exp.Command is sqlglot's fallback for any syntax it doesn't recognize.
# Including it makes the denylist fail-closed (safer), but may block legitimate
# vendor-specific SQL that sqlglot can't parse. Callers who need such syntax can
Expand All @@ -53,6 +93,11 @@
exp.Merge,
exp.Into,
exp.Command,
# DDL — newly reachable through the DESCRIBE/SHOW allowlist, so deny it here too.
exp.Create,
exp.Drop,
exp.Alter,
exp.TruncateTable,
)


Expand All @@ -66,6 +111,7 @@ def validate_sql(
allowed_types: tuple[type[exp.Expr], ...] | None = None,
dialect: str | None = None,
allow_multiple_statements: bool = False,
allow_read_only_metadata: bool = False,
) -> list[exp.Expr]:
"""
Parse SQL and verify all statements are in the allowed types list.
Expand All @@ -78,18 +124,35 @@ def validate_sql(

:param sql: SQL string to validate.
:param allowed_types: Tuple of sqlglot expression types to permit.
Defaults to ``(Select, Union, Intersect, Except)``.
Defaults to ``(Select, Union, Intersect, Except)``. When supplied, the
caller takes full control of the allow-list and ``allow_read_only_metadata``
is ignored.
:param dialect: SQL dialect for parsing (``postgres``, ``mysql``, etc.).
:param allow_multiple_statements: Whether to allow multiple semicolon-separated
statements. Default ``False``.
:param allow_read_only_metadata: Also permit read-only metadata statements
(``DESCRIBE``/``SHOW``) on top of the default read-only allow-list. Ignored
when ``allowed_types`` is supplied. Note ``SHOW`` only parses to a metadata
statement when a ``dialect`` that supports it is given. Default ``False``.
:return: List of parsed sqlglot Expression objects.
:raises SQLSafetyError: If the SQL is empty, contains disallowed statement types,
or has multiple statements when not permitted.
"""
if not sql or not sql.strip():
raise SQLSafetyError("Empty SQL input.")

types = allowed_types or DEFAULT_ALLOWED_TYPES
# A caller-supplied ``allowed_types`` is an explicit opt-out of the curated
# read-only defaults; the data-modifying deep scan is then skipped unless the
# supplied tuple equals ``DEFAULT_ALLOWED_TYPES``. Otherwise we use the
# read-only defaults, optionally widened with metadata statements, and keep
# the deep scan on.
if allowed_types is None:
types: tuple[type[exp.Expr], ...] = DEFAULT_ALLOWED_TYPES
if allow_read_only_metadata:
types = types + READ_ONLY_METADATA_TYPES
run_data_modifying_scan = True
else:
types = allowed_types
run_data_modifying_scan = types == DEFAULT_ALLOWED_TYPES

try:
statements = sqlglot.parse(sql, dialect=dialect, error_level=ErrorLevel.RAISE)
Expand All @@ -114,10 +177,10 @@ def validate_sql(
)

# Deep scan: reject data-modifying nodes hidden inside otherwise-allowed statements
# (e.g. data-modifying CTEs, SELECT INTO). Only applies when using the default
# read-only allowlist — callers who provide custom allowed_types have explicitly
# (e.g. data-modifying CTEs, SELECT INTO, EXPLAIN <write>). Runs for the curated
# read-only allow-list — callers who provide custom allowed_types have explicitly
# opted into non-read-only operations.
if types is DEFAULT_ALLOWED_TYPES:
if run_data_modifying_scan:
_check_for_data_modifying_nodes(parsed)

return parsed
Expand Down
62 changes: 62 additions & 0 deletions providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,3 +494,65 @@ def test_list_tables_deduplicates_same_table(self):

result = json.loads(asyncio.run(ts.call_tool("list_tables", {}, ctx=MagicMock(), tool=MagicMock())))
assert result == ["public.users"]


class TestSQLToolsetMetadataStatements:
    """Cover DESCRIBE/SHOW handling in the ``query`` and ``check_query`` tools."""

    def test_describe_allowed_through_query(self):
        """DESCRIBE passes validation and is forwarded verbatim to the hook."""
        toolset = SQLToolset("pg_default")
        toolset._hook = _make_mock_db_hook(
            records=[("id", "INTEGER"), ("name", "VARCHAR")],
            last_description=[("column_name",), ("data_type",)],
        )

        raw = asyncio.run(
            toolset.call_tool("query", {"sql": "DESCRIBE TABLE users"}, ctx=MagicMock(), tool=MagicMock())
        )

        payload = json.loads(raw)
        assert "rows" in payload
        toolset._hook.get_records.assert_called_once_with("DESCRIBE TABLE users")

    def test_show_allowed_with_snowflake_dialect(self):
        """With a Snowflake hook dialect, SHOW is accepted and reaches the hook."""
        toolset = SQLToolset("sf_default")
        toolset._hook = _make_mock_db_hook(records=[("USERS",)], last_description=[("name",)])
        toolset._hook.dialect_name = "snowflake"

        raw = asyncio.run(
            toolset.call_tool("query", {"sql": "SHOW TABLES"}, ctx=MagicMock(), tool=MagicMock())
        )

        payload = json.loads(raw)
        assert "rows" in payload
        toolset._hook.get_records.assert_called_once_with("SHOW TABLES")

    @pytest.mark.parametrize(
        "sql",
        [
            # With the Postgres dialect SHOW falls back to Command, which is denied.
            pytest.param("SHOW TABLES", id="show_without_dialect_support"),
            # A plain data-modifying statement.
            pytest.param("DELETE FROM users", id="write"),
        ],
    )
    def test_query_blocks_disallowed_statements(self, sql):
        """Statements outside the read-only allow-list raise ``SQLSafetyError``."""
        toolset = SQLToolset("pg_default")
        toolset._hook = _make_mock_db_hook()
        toolset._hook.dialect_name = "postgresql"

        with pytest.raises(SQLSafetyError, match="not allowed"):
            asyncio.run(toolset.call_tool("query", {"sql": sql}, ctx=MagicMock(), tool=MagicMock()))

    def test_check_query_accepts_describe(self):
        """``check_query`` reports DESCRIBE as valid."""
        toolset = SQLToolset("pg_default")
        toolset._hook = _make_mock_db_hook()

        raw = asyncio.run(
            toolset.call_tool(
                "check_query", {"sql": "DESCRIBE TABLE users"}, ctx=MagicMock(), tool=MagicMock()
            )
        )

        verdict = json.loads(raw)
        assert verdict["valid"] is True

    def test_check_query_handles_unresolvable_connection(self):
        """``check_query`` still validates when the DB hook cannot be resolved."""
        toolset = SQLToolset("missing_conn")
        hook_failure = RuntimeError("no such connection")

        with patch.object(toolset, "_get_db_hook", side_effect=hook_failure):
            raw = asyncio.run(
                toolset.call_tool("check_query", {"sql": "SELECT 1"}, ctx=MagicMock(), tool=MagicMock())
            )

        verdict = json.loads(raw)
        assert verdict["valid"] is True
Loading
Loading