diff --git a/providers/common/ai/docs/toolsets.rst b/providers/common/ai/docs/toolsets.rst index 617c63520b113..b5e868abea209 100644 --- a/providers/common/ai/docs/toolsets.rst +++ b/providers/common/ai/docs/toolsets.rst @@ -146,6 +146,20 @@ Curated toolset wrapping The ``DbApiHook`` is resolved lazily from ``db_conn_id`` on first tool call via ``BaseHook.get_connection(conn_id).get_hook()``. +In read-only mode (``allow_writes=False``, the default) the ``query`` tool also +accepts read-only metadata statements -- ``DESCRIBE``/``DESC`` and ``SHOW`` -- +in addition to SELECT-family queries. Agents commonly open with ``DESCRIBE`` to +learn a table's columns, so permitting it keeps runs deterministic instead of +hard-failing on schema discovery. The toolset passes the connection's dialect to +the validator, so ``SHOW`` is recognized on databases that support it (Snowflake, +MySQL, etc.); on databases without ``SHOW`` it stays rejected. Data-modifying +statements remain blocked -- including ones hidden behind ``DESCRIBE``/``EXPLAIN`` +(e.g. ``EXPLAIN DELETE ...``, ``DESCRIBE DROP TABLE ...``), which the validator +rejects by scanning the parsed statement for write operations. Like ``SELECT``, +metadata statements are not scoped by ``allowed_tables`` (see +:ref:`allowed-tables-limitation`) -- an agent can ``DESCRIBE`` a table outside the +list, so rely on database permissions to restrict access. + Multi-schema warehouses ^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -180,7 +194,8 @@ Parameters - ``schema``: Default schema/namespace for unqualified table listing and introspection. Schema-qualified ``allowed_tables`` entries override it per table. - ``allow_writes``: Allow data-modifying SQL (INSERT, UPDATE, DELETE, etc.). - Default ``False`` — only SELECT-family statements are permitted. + Default ``False`` -- only SELECT-family and read-only metadata + (``DESCRIBE``/``SHOW``) statements are permitted. - ``max_rows``: Maximum rows returned from the ``query`` tool. Default ``50``. 
``DataFusionToolset`` @@ -534,7 +549,9 @@ No single layer is sufficient — they work together. - Does not restrict what arguments the agent passes to allowed methods. * - **SQLToolset: read-only by default** - ``allow_writes=False`` (default) validates every SQL query through - ``validate_sql()`` and rejects INSERT, UPDATE, DELETE, DROP, etc. + ``validate_sql()``: SELECT-family and read-only metadata + (``DESCRIBE``/``SHOW``) statements pass; INSERT, UPDATE, DELETE, DROP, + and writes hidden behind ``EXPLAIN`` are rejected. - Does not prevent the agent from reading sensitive data that the database user has SELECT access to. * - **DataFusionToolset: read-only by default** diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py b/providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py index 370c819fa835e..7342be2b7e34f 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py +++ b/providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py @@ -25,6 +25,7 @@ try: from airflow.providers.common.ai.utils.sql_validation import ( DEFAULT_ALLOWED_TYPES, + resolve_sqlglot_dialect, validate_sql as _validate_sql, ) from airflow.providers.common.sql.datafusion.engine import DataFusionEngine @@ -44,12 +45,6 @@ from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.sdk import Context -# SQLAlchemy dialect_name → sqlglot dialect mapping for names that differ. 
-_SQLALCHEMY_TO_SQLGLOT_DIALECT: dict[str, str] = { - "postgresql": "postgres", - "mssql": "tsql", -} - class LLMSQLQueryOperator(LLMOperator): """ @@ -257,6 +252,4 @@ def _resolved_dialect(self) -> str | None: raw = self.dialect if not raw and self.db_hook and hasattr(self.db_hook, "dialect_name"): raw = self.db_hook.dialect_name - if raw: - return _SQLALCHEMY_TO_SQLGLOT_DIALECT.get(raw, raw) - return None + return resolve_sqlglot_dialect(raw) diff --git a/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py b/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py index fca07177597e9..ee3128705a1f9 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py +++ b/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py @@ -24,7 +24,10 @@ from typing import TYPE_CHECKING, Any try: - from airflow.providers.common.ai.utils.sql_validation import validate_sql as _validate_sql + from airflow.providers.common.ai.utils.sql_validation import ( + resolve_sqlglot_dialect, + validate_sql as _validate_sql, + ) from airflow.providers.common.sql.hooks.sql import DbApiHook except ImportError as e: from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException @@ -297,11 +300,23 @@ def _get_schema(self, table_name: str) -> str: columns = hook.get_table_schema(table, schema=schema) return json.dumps(columns) + def _dialect_for_validation(self) -> str | None: + """Resolve the hook's sqlglot dialect so DESCRIBE/SHOW validate correctly.""" + hook = self._get_db_hook() + return resolve_sqlglot_dialect(getattr(hook, "dialect_name", None)) + def _query(self, sql: str) -> str: + hook = self._get_db_hook() if not self._allow_writes: - _validate_sql(sql) + # allow_read_only_metadata lets agents inspect schemas with DESCRIBE/SHOW + # (a common first move) instead of hard-failing; the deep scan still + # rejects any data-modifying statement, including EXPLAIN-wrapped writes.
+ _validate_sql( + sql, + dialect=self._dialect_for_validation(), + allow_read_only_metadata=True, + ) - hook = self._get_db_hook() try: rows = hook.get_records(sql) except Exception as e: @@ -347,8 +362,13 @@ def _is_retryable_query_error(hook: DbApiHook, error: Exception) -> bool: return False def _check_query(self, sql: str) -> str: + # Resolve the dialect best-effort: if the connection can't be reached we + # still syntax-check dialect-agnostically rather than reporting invalid. + dialect: str | None = None + with suppress(Exception): + dialect = self._dialect_for_validation() try: - _validate_sql(sql) + _validate_sql(sql, dialect=dialect, allow_read_only_metadata=True) return json.dumps({"valid": True}) except Exception as e: return json.dumps({"valid": False, "error": str(e)}) diff --git a/providers/common/ai/src/airflow/providers/common/ai/utils/sql_validation.py b/providers/common/ai/src/airflow/providers/common/ai/utils/sql_validation.py index 3ab87516ff641..a00b4dc11e62d 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/utils/sql_validation.py +++ b/providers/common/ai/src/airflow/providers/common/ai/utils/sql_validation.py @@ -26,8 +26,36 @@ import sqlglot from sqlglot import exp +from sqlglot.dialects import Dialects from sqlglot.errors import ErrorLevel +# Dialect names sqlglot recognizes. Used to drop unknown dialect names so a bad +# value never breaks parsing (sqlglot raises on an unknown dialect). +_KNOWN_SQLGLOT_DIALECTS: frozenset[str] = frozenset(d.value for d in Dialects) + +# SQLAlchemy ``dialect_name`` → sqlglot dialect mapping for names that differ. +_SQLALCHEMY_TO_SQLGLOT_DIALECT: dict[str, str] = { + "postgresql": "postgres", + "mssql": "tsql", +} + + +def resolve_sqlglot_dialect(dialect_name: str | None) -> str | None: + """ + Normalize a SQLAlchemy dialect name to a sqlglot dialect. 
+ + Returns ``None`` (dialect-agnostic parsing) for empty, non-string, or + unknown inputs, so a bad dialect value never breaks SQL validation. + + :param dialect_name: A SQLAlchemy ``dialect_name`` (e.g. ``"postgresql"``). + :return: The matching sqlglot dialect (e.g. ``"postgres"``), or ``None``. + """ + if not isinstance(dialect_name, str) or not dialect_name: + return None + mapped = _SQLALCHEMY_TO_SQLGLOT_DIALECT.get(dialect_name, dialect_name) + return mapped if mapped in _KNOWN_SQLGLOT_DIALECTS else None + + # Allowlist: only these top-level statement types pass validation by default. # - Select: plain queries and CTE-wrapped queries (WITH ... AS ... SELECT is parsed # as Select with a `with` clause property — still a Select node at the top level) @@ -39,9 +67,21 @@ exp.Except, ) +# Read-only metadata statements that introspect the schema without touching data: +# - Describe: DESCRIBE / DESC (and EXPLAIN on some dialects) +# - Show: SHOW TABLES / SHOW COLUMNS / SHOW DATABASES, etc. +# Opt-in via ``allow_read_only_metadata=True``. SHOW only parses to ``exp.Show`` +# when a dialect that supports it is passed (e.g. snowflake, mysql); without a +# dialect sqlglot falls back to ``exp.Command``, which stays blocked. +READ_ONLY_METADATA_TYPES: tuple[type[exp.Expr], ...] = ( + exp.Describe, + exp.Show, +) + # Denylist: expression types that mutate data or schema when found anywhere in the AST. # This catches data-modifying CTEs (e.g. WITH del AS (DELETE …) SELECT …), -# SELECT INTO, and other constructs that bypass top-level type checks. +# SELECT INTO, DDL or DML wrapped behind DESCRIBE/EXPLAIN (e.g. DESCRIBE DROP TABLE …), +# and other constructs that bypass top-level type checks. # Note: exp.Command is sqlglot's fallback for any syntax it doesn't recognize. # Including it makes the denylist fail-closed (safer), but may block legitimate # vendor-specific SQL that sqlglot can't parse. 
Callers who need such syntax can @@ -53,6 +93,11 @@ exp.Merge, exp.Into, exp.Command, + # DDL — newly reachable through the DESCRIBE/SHOW allowlist, so deny it here too. + exp.Create, + exp.Drop, + exp.Alter, + exp.TruncateTable, ) @@ -66,6 +111,7 @@ def validate_sql( allowed_types: tuple[type[exp.Expr], ...] | None = None, dialect: str | None = None, allow_multiple_statements: bool = False, + allow_read_only_metadata: bool = False, ) -> list[exp.Expr]: """ Parse SQL and verify all statements are in the allowed types list. @@ -78,10 +124,16 @@ def validate_sql( :param sql: SQL string to validate. :param allowed_types: Tuple of sqlglot expression types to permit. - Defaults to ``(Select, Union, Intersect, Except)``. + Defaults to ``(Select, Union, Intersect, Except)``. When supplied, the + caller takes full control of the allow-list and ``allow_read_only_metadata`` + is ignored. :param dialect: SQL dialect for parsing (``postgres``, ``mysql``, etc.). :param allow_multiple_statements: Whether to allow multiple semicolon-separated statements. Default ``False``. + :param allow_read_only_metadata: Also permit read-only metadata statements + (``DESCRIBE``/``SHOW``) on top of the default read-only allow-list. Ignored + when ``allowed_types`` is supplied. Note ``SHOW`` only parses to a metadata + statement when a ``dialect`` that supports it is given. Default ``False``. :return: List of parsed sqlglot Expression objects. :raises SQLSafetyError: If the SQL is empty, contains disallowed statement types, or has multiple statements when not permitted. @@ -89,7 +141,18 @@ def validate_sql( if not sql or not sql.strip(): raise SQLSafetyError("Empty SQL input.") - types = allowed_types or DEFAULT_ALLOWED_TYPES + # A caller-supplied ``allowed_types`` is an explicit opt-out of the curated + # read-only defaults (and the data-modifying deep scan). Otherwise we use the + # read-only defaults, optionally widened with metadata statements, and keep + # the deep scan on. 
+ if allowed_types is None: + types: tuple[type[exp.Expr], ...] = DEFAULT_ALLOWED_TYPES + if allow_read_only_metadata: + types = types + READ_ONLY_METADATA_TYPES + run_data_modifying_scan = True + else: + types = allowed_types + run_data_modifying_scan = types == DEFAULT_ALLOWED_TYPES try: statements = sqlglot.parse(sql, dialect=dialect, error_level=ErrorLevel.RAISE) @@ -114,10 +177,10 @@ def validate_sql( ) # Deep scan: reject data-modifying nodes hidden inside otherwise-allowed statements - # (e.g. data-modifying CTEs, SELECT INTO). Only applies when using the default - # read-only allowlist — callers who provide custom allowed_types have explicitly + # (e.g. data-modifying CTEs, SELECT INTO, EXPLAIN-wrapped writes). Runs for the curated + # read-only allow-list — callers who provide custom allowed_types have explicitly # opted into non-read-only operations. - if types is DEFAULT_ALLOWED_TYPES: + if run_data_modifying_scan: _check_for_data_modifying_nodes(parsed) return parsed diff --git a/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py b/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py index c1aae15aad542..5e425597a32e1 100644 --- a/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py +++ b/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py @@ -494,3 +494,65 @@ def test_list_tables_deduplicates_same_table(self): result = json.loads(asyncio.run(ts.call_tool("list_tables", {}, ctx=MagicMock(), tool=MagicMock()))) assert result == ["public.users"] + + +class TestSQLToolsetMetadataStatements: + """Read-only metadata statements (DESCRIBE/SHOW) flow through the query tool.""" + + def test_describe_allowed_through_query(self): + """DESCRIBE is read-only metadata and should not be rejected as unsafe.""" + ts = SQLToolset("pg_default") + ts._hook = _make_mock_db_hook( + records=[("id", "INTEGER"), ("name", "VARCHAR")], + last_description=[("column_name",), ("data_type",)], + ) + + result = asyncio.run( + ts.call_tool("query", {"sql":
"DESCRIBE TABLE users"}, ctx=MagicMock(), tool=MagicMock()) + ) + data = json.loads(result) + assert "rows" in data + ts._hook.get_records.assert_called_once_with("DESCRIBE TABLE users") + + def test_show_allowed_with_snowflake_dialect(self): + """SHOW parses to a metadata statement once the hook's dialect is passed through.""" + ts = SQLToolset("sf_default") + ts._hook = _make_mock_db_hook(records=[("USERS",)], last_description=[("name",)]) + ts._hook.dialect_name = "snowflake" + + result = asyncio.run(ts.call_tool("query", {"sql": "SHOW TABLES"}, ctx=MagicMock(), tool=MagicMock())) + data = json.loads(result) + assert "rows" in data + ts._hook.get_records.assert_called_once_with("SHOW TABLES") + + @pytest.mark.parametrize( + "sql", + # SHOW falls back to Command on Postgres (no SHOW support); DELETE is a write. + ["SHOW TABLES", "DELETE FROM users"], + ids=["show_without_dialect_support", "write"], + ) + def test_query_blocks_disallowed_statements(self, sql): + ts = SQLToolset("pg_default") + ts._hook = _make_mock_db_hook() + ts._hook.dialect_name = "postgresql" + + with pytest.raises(SQLSafetyError, match="not allowed"): + asyncio.run(ts.call_tool("query", {"sql": sql}, ctx=MagicMock(), tool=MagicMock())) + + def test_check_query_accepts_describe(self): + ts = SQLToolset("pg_default") + ts._hook = _make_mock_db_hook() + + result = asyncio.run( + ts.call_tool("check_query", {"sql": "DESCRIBE TABLE users"}, ctx=MagicMock(), tool=MagicMock()) + ) + assert json.loads(result)["valid"] is True + + def test_check_query_handles_unresolvable_connection(self): + """check_query stays usable (dialect-agnostic) when the connection can't be resolved.""" + ts = SQLToolset("missing_conn") + with patch.object(ts, "_get_db_hook", side_effect=RuntimeError("no such connection")): + result = asyncio.run( + ts.call_tool("check_query", {"sql": "SELECT 1"}, ctx=MagicMock(), tool=MagicMock()) + ) + assert json.loads(result)["valid"] is True diff --git 
a/providers/common/ai/tests/unit/common/ai/utils/test_sql_validation.py b/providers/common/ai/tests/unit/common/ai/utils/test_sql_validation.py index fe27379aebec7..9ca6604ba5e61 100644 --- a/providers/common/ai/tests/unit/common/ai/utils/test_sql_validation.py +++ b/providers/common/ai/tests/unit/common/ai/utils/test_sql_validation.py @@ -19,7 +19,11 @@ import pytest from sqlglot import exp -from airflow.providers.common.ai.utils.sql_validation import SQLSafetyError, validate_sql +from airflow.providers.common.ai.utils.sql_validation import ( + SQLSafetyError, + resolve_sqlglot_dialect, + validate_sql, +) class TestValidateSQLAllowed: @@ -213,3 +217,98 @@ def test_deep_scan_runs_with_explicit_default_types(self): allowed_types=DEFAULT_ALLOWED_TYPES, dialect="postgres", ) + + +class TestReadOnlyMetadata: + """Read-only metadata statements (DESCRIBE/SHOW) with ``allow_read_only_metadata``.""" + + @pytest.mark.parametrize( + ("sql", "kwargs"), + [ + ("DESCRIBE TABLE users", {}), + ("SHOW TABLES", {"dialect": "snowflake"}), + ], + ids=["describe", "show"], + ) + def test_metadata_blocked_without_flag(self, sql, kwargs): + with pytest.raises(SQLSafetyError, match="not allowed"): + validate_sql(sql, **kwargs) + + @pytest.mark.parametrize( + ("sql", "dialect", "expected_type"), + [ + # DESCRIBE/DESC parse to exp.Describe in every dialect (dialect-agnostic). + ("DESCRIBE TABLE users", None, exp.Describe), + ("DESC users", None, exp.Describe), + # SHOW only parses to exp.Show when a supporting dialect is passed. 
+ ("SHOW TABLES", "snowflake", exp.Show), + ("SHOW COLUMNS IN users", "snowflake", exp.Show), + ], + ) + def test_metadata_allowed_with_flag(self, sql, dialect, expected_type): + result = validate_sql(sql, dialect=dialect, allow_read_only_metadata=True) + assert len(result) == 1 + assert isinstance(result[0], expected_type) + + def test_show_blocked_without_supporting_dialect(self): + """Without a dialect that supports SHOW, sqlglot falls back to exp.Command, still blocked.""" + with pytest.raises(SQLSafetyError, match="Command.*not allowed"): + validate_sql("SHOW TABLES", allow_read_only_metadata=True) + + def test_explain_wrapped_write_still_blocked(self): + """EXPLAIN parses to exp.Describe but the deep scan rejects the inner write.""" + with pytest.raises(SQLSafetyError, match="Data-modifying operation 'Delete'"): + validate_sql("EXPLAIN DELETE FROM users", dialect="mysql", allow_read_only_metadata=True) + + @pytest.mark.parametrize( + ("sql", "node"), + [ + ("DESCRIBE CREATE TABLE t (a int)", "Create"), + ("DESCRIBE DROP TABLE users", "Drop"), + ("DESCRIBE TRUNCATE TABLE users", "TruncateTable"), + ("DESCRIBE DELETE FROM users", "Delete"), + ], + ) + def test_describe_wrapped_ddl_or_dml_blocked(self, sql, node): + """DESCRIBE parses to exp.Describe; the deep scan rejects the inner write.""" + with pytest.raises(SQLSafetyError, match=f"Data-modifying operation '{node}'"): + validate_sql(sql, allow_read_only_metadata=True) + + def test_metadata_flag_ignored_when_custom_types_supplied(self): + """When the caller supplies allowed_types it controls the allow-list; the flag is ignored.""" + with pytest.raises(SQLSafetyError, match="Describe.*not allowed"): + validate_sql( + "DESCRIBE TABLE users", + allowed_types=(exp.Select,), + allow_read_only_metadata=True, + ) + + def test_select_still_allowed_with_flag(self): + result = validate_sql("SELECT 1", allow_read_only_metadata=True) + assert isinstance(result[0], exp.Select) + + def 
test_writes_still_blocked_with_flag(self): + with pytest.raises(SQLSafetyError, match="Delete.*not allowed"): + validate_sql("DELETE FROM users WHERE id = 1", allow_read_only_metadata=True) + + +class TestResolveSqlglotDialect: + """``resolve_sqlglot_dialect`` normalizes/validates SQLAlchemy dialect names.""" + + @pytest.mark.parametrize( + ("dialect_name", "expected"), + [ + ("postgresql", "postgres"), + ("mssql", "tsql"), + ("mysql", "mysql"), + ("snowflake", "snowflake"), + ("sqlite", "sqlite"), + (None, None), + ("", None), + ("default", None), + ("not_a_real_dialect", None), + (123, None), + ], + ) + def test_resolution(self, dialect_name, expected): + assert resolve_sqlglot_dialect(dialect_name) == expected