Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def _get_conn_params(self) -> dict[str, str | None]:

This is used in ``get_uri()`` and ``get_connection()``.
"""
conn = self.get_connection(self.snowflake_conn_id) # type: ignore[attr-defined]
conn = self.get_connection(self.get_conn_id())
extra_dict = conn.extra_dejson
account = self._get_field(extra_dict, "account") or ""
warehouse = self._get_field(extra_dict, "warehouse") or ""
Expand Down Expand Up @@ -461,7 +461,7 @@ def set_autocommit(self, conn, autocommit: Any) -> None:
def get_autocommit(self, conn):
return getattr(conn, "autocommit_mode", False)

@overload # type: ignore[override]
@overload
def run(
self,
sql: str | Iterable[str],
Expand Down Expand Up @@ -544,16 +544,16 @@ def run(
results = []
for sql_statement in sql_list:
self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters)
self._run_command(cur, sql_statement, parameters) # type: ignore[attr-defined]
self._run_command(cur, sql_statement, parameters)

if handler is not None:
result = self._make_common_data_structure(handler(cur)) # type: ignore[attr-defined]
result = self._make_common_data_structure(handler(cur))
if return_single_query_results(sql, return_last, split_statements):
_last_result = result
_last_description = cur.description
else:
results.append(result)
self.descriptions.append(cur.description) # type: ignore[has-type]
self.descriptions.append(cur.description)

query_id = cur.sfqid
self.log.info("Rows affected: %s", cur.rowcount)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ async def _make_api_call_with_retries_async(self, method, url, headers, params=N
:return: The response object from the API call.
"""
async with aiohttp.ClientSession(headers=headers) as session:
async for attempt in AsyncRetrying(**self.retry_config): # type: ignore
async for attempt in AsyncRetrying(**self.retry_config):
with attempt:
if method.upper() == "GET":
async with session.request(method=method.lower(), url=url, params=params) as response:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ def execute(self, context: Context) -> None:
"""
self.log.info("Executing: %s", self.sql)
self.query_ids = self._hook.execute_query(
self.sql, # type: ignore[arg-type]
self.sql,
statement_count=self.statement_count,
bindings=self.bindings,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,11 @@ def _get_queries_details_from_snowflake(
)

try:
# Can't import the SnowflakeSqlApiHook class and do proper isinstance check - circular imports
if hook.__class__.__name__ == "SnowflakeSqlApiHook":
result = _run_single_query_with_api_hook(hook=hook, sql=query) # type: ignore[arg-type]
# Note: need to lazy import here to avoid circular imports
from airflow.providers.snowflake.hooks.snowflake_sql_api import SnowflakeSqlApiHook

if isinstance(hook, SnowflakeSqlApiHook):
result = _run_single_query_with_api_hook(hook=hook, sql=query)
result = _process_data_from_api(data=result)
else:
result = _run_single_query_with_hook(hook=hook, sql=query)
Expand Down Expand Up @@ -426,8 +428,8 @@ def emit_openlineage_events_for_snowflake_queries(
event_batch = _create_snowflake_event_pair(
job_namespace=namespace(),
job_name=f"{task_instance.dag_id}.{task_instance.task_id}.query.{counter}",
start_time=query_metadata.get("START_TIME", default_event_time), # type: ignore[arg-type]
end_time=query_metadata.get("END_TIME", default_event_time), # type: ignore[arg-type]
start_time=query_metadata.get("START_TIME", default_event_time),
end_time=query_metadata.get("END_TIME", default_event_time),
# `EXECUTION_STATUS` can be `success`, `fail` or `incident` (Snowflake outage, so still failure)
is_successful=query_metadata.get("EXECUTION_STATUS", default_state).lower() == "success",
run_facets={**query_specific_run_facets, **common_run_facets, **additional_run_facets},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
if AIRFLOW_V_3_0_PLUS:
from airflow.sdk import BaseOperator
else:
from airflow.models import BaseOperator # type: ignore[no-redef]
from airflow.models import BaseOperator

__all__ = [
"AIRFLOW_V_3_0_PLUS",
Expand Down