From 9a9e702264c73baa43ec09d2458917444ccd181b Mon Sep 17 00:00:00 2001 From: Jens Scheffler Date: Sat, 12 Jul 2025 22:22:16 +0200 Subject: [PATCH 1/3] Cleanup type ignores in snowflake provider where possible --- .../src/airflow/providers/snowflake/hooks/snowflake.py | 10 +++++----- .../providers/snowflake/hooks/snowflake_sql_api.py | 2 +- .../airflow/providers/snowflake/operators/snowflake.py | 2 +- .../airflow/providers/snowflake/utils/openlineage.py | 6 +++--- .../src/airflow/providers/snowflake/version_compat.py | 2 +- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py index ff55bfa21abae..0f80de9d6a04b 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py @@ -253,7 +253,7 @@ def _get_conn_params(self) -> dict[str, str | None]: This is used in ``get_uri()`` and ``get_connection()``. """ - conn = self.get_connection(self.snowflake_conn_id) # type: ignore[attr-defined] + conn = self.get_connection(self.snowflake_conn_id) extra_dict = conn.extra_dejson account = self._get_field(extra_dict, "account") or "" warehouse = self._get_field(extra_dict, "warehouse") or "" @@ -461,7 +461,7 @@ def set_autocommit(self, conn, autocommit: Any) -> None: def get_autocommit(self, conn): return getattr(conn, "autocommit_mode", False) - @overload # type: ignore[override] + @overload def run( self, sql: str | Iterable[str], @@ -544,16 +544,16 @@ def run( results = [] for sql_statement in sql_list: self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters) - self._run_command(cur, sql_statement, parameters) # type: ignore[attr-defined] + self._run_command(cur, sql_statement, parameters) if handler is not None: - result = self._make_common_data_structure(handler(cur)) # type: ignore[attr-defined] + result = self._make_common_data_structure(handler(cur)) if return_single_query_results(sql, return_last, split_statements): _last_result = result _last_description = cur.description else: results.append(result) - self.descriptions.append(cur.description) # type: ignore[has-type] + self.descriptions.append(cur.description) query_id = cur.sfqid self.log.info("Rows affected: %s", cur.rowcount) diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py index f3b23ad3a585d..48747381602eb 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py @@ -464,7 +464,7 @@ async def _make_api_call_with_retries_async(self, method, url, headers, params=N :return: The response object from the API call. """ async with aiohttp.ClientSession(headers=headers) as session: - async for attempt in AsyncRetrying(**self.retry_config): # type: ignore + async for attempt in AsyncRetrying(**self.retry_config): with attempt: if method.upper() == "GET": async with session.request(method=method.lower(), url=url, params=params) as response: diff --git a/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py index 48086e53cd8c7..4f214c681fb30 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py +++ b/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py @@ -427,7 +427,7 @@ def execute(self, context: Context) -> None: """ self.log.info("Executing: %s", self.sql) self.query_ids = self._hook.execute_query( - self.sql, # type: ignore[arg-type] + self.sql, statement_count=self.statement_count, bindings=self.bindings, ) diff --git a/providers/snowflake/src/airflow/providers/snowflake/utils/openlineage.py b/providers/snowflake/src/airflow/providers/snowflake/utils/openlineage.py index 88c5caaa45e0b..9ae33ebd4c8ec 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/utils/openlineage.py +++ b/providers/snowflake/src/airflow/providers/snowflake/utils/openlineage.py @@ -251,7 +251,7 @@ def _get_queries_details_from_snowflake( try: # Can't import the SnowflakeSqlApiHook class and do proper isinstance check - circular imports if hook.__class__.__name__ == "SnowflakeSqlApiHook": - result = _run_single_query_with_api_hook(hook=hook, sql=query) # type: ignore[arg-type] + result = _run_single_query_with_api_hook(hook=hook, sql=query) result = _process_data_from_api(data=result) else: result = _run_single_query_with_hook(hook=hook, sql=query) @@ -426,8 +426,8 @@ def emit_openlineage_events_for_snowflake_queries( event_batch = _create_snowflake_event_pair( job_namespace=namespace(), job_name=f"{task_instance.dag_id}.{task_instance.task_id}.query.{counter}", - start_time=query_metadata.get("START_TIME", default_event_time), # type: ignore[arg-type] - end_time=query_metadata.get("END_TIME", default_event_time), # type: ignore[arg-type] + start_time=query_metadata.get("START_TIME", default_event_time), + end_time=query_metadata.get("END_TIME", default_event_time), # `EXECUTION_STATUS` can be `success`, `fail` or `incident` (Snowflake outage, so still failure) is_successful=query_metadata.get("EXECUTION_STATUS", default_state).lower() == "success", run_facets={**query_specific_run_facets, **common_run_facets, **additional_run_facets}, diff --git a/providers/snowflake/src/airflow/providers/snowflake/version_compat.py b/providers/snowflake/src/airflow/providers/snowflake/version_compat.py index df1b3e7037841..e7a259afb357c 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/version_compat.py +++ b/providers/snowflake/src/airflow/providers/snowflake/version_compat.py @@ -37,7 +37,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: if AIRFLOW_V_3_0_PLUS: from airflow.sdk import BaseOperator else: - from airflow.models import BaseOperator # type: ignore[no-redef] + from airflow.models import BaseOperator __all__ = [ "AIRFLOW_V_3_0_PLUS", From 3d7d6f5b9d553f2cf7c5d5c6778cf7dfa243304c Mon Sep 17 00:00:00 2001 From: Jens Scheffler Date: Sun, 13 Jul 2025 10:44:27 +0200 Subject: [PATCH 2/3] Get the hook the right way via get_conn_id() --- .../src/airflow/providers/snowflake/hooks/snowflake.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py index 0f80de9d6a04b..fd7e28804c19d 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py @@ -253,7 +253,7 @@ def _get_conn_params(self) -> dict[str, str | None]: This is used in ``get_uri()`` and ``get_connection()``. """ - conn = self.get_connection(self.snowflake_conn_id) + conn = self.get_connection(self.get_conn_id()) extra_dict = conn.extra_dejson account = self._get_field(extra_dict, "account") or "" warehouse = self._get_field(extra_dict, "warehouse") or "" From 55db5d921c4af7169b521ec65fa94f59dfbc4793 Mon Sep 17 00:00:00 2001 From: Jens Scheffler Date: Sun, 13 Jul 2025 11:07:28 +0200 Subject: [PATCH 3/3] Make lazy import to prevent mypy complaints --- .../src/airflow/providers/snowflake/utils/openlineage.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/providers/snowflake/src/airflow/providers/snowflake/utils/openlineage.py b/providers/snowflake/src/airflow/providers/snowflake/utils/openlineage.py index 9ae33ebd4c8ec..d06a6c463e406 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/utils/openlineage.py +++ b/providers/snowflake/src/airflow/providers/snowflake/utils/openlineage.py @@ -249,8 +249,10 @@ def _get_queries_details_from_snowflake( ) try: - # Can't import the SnowflakeSqlApiHook class and do proper isinstance check - circular imports - if hook.__class__.__name__ == "SnowflakeSqlApiHook": + # Note: need to lazy import here to avoid circular imports + from airflow.providers.snowflake.hooks.snowflake_sql_api import SnowflakeSqlApiHook + + if isinstance(hook, SnowflakeSqlApiHook): result = _run_single_query_with_api_hook(hook=hook, sql=query) result = _process_data_from_api(data=result) else: