diff --git a/providers/mysql/src/airflow/providers/mysql/hooks/mysql.py b/providers/mysql/src/airflow/providers/mysql/hooks/mysql.py index ffee8ad187018..7fa2d6434194c 100644 --- a/providers/mysql/src/airflow/providers/mysql/hooks/mysql.py +++ b/providers/mysql/src/airflow/providers/mysql/hooks/mysql.py @@ -30,7 +30,12 @@ logger = logging.getLogger(__name__) if TYPE_CHECKING: - from airflow.models import Connection + from airflow.providers.mysql.version_compat import AIRFLOW_V_3_0_PLUS + + if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import Connection + else: + from airflow.models.connection import Connection # type: ignore[assignment] try: from mysql.connector.abstracts import MySQLConnectionAbstract @@ -130,7 +135,7 @@ def _get_conn_config_mysql_client(self, conn: Connection) -> dict: if conn.extra_dejson.get("charset", False): conn_config["charset"] = conn.extra_dejson["charset"] - if conn_config["charset"].lower() in ("utf8", "utf-8"): + if str(conn_config.get("charset", "undef")).lower() in ("utf8", "utf-8"): conn_config["use_unicode"] = True if conn.extra_dejson.get("cursor", False): try: @@ -220,7 +225,7 @@ def get_conn(self) -> MySQLConnectionTypes: "installed in case you see compilation error during installation." ) - conn_config = self._get_conn_config_mysql_client(conn) # type: ignore[arg-type] + conn_config = self._get_conn_config_mysql_client(conn) return MySQLdb.connect(**conn_config) if client_name == "mysql-connector-python": @@ -233,7 +238,7 @@ def get_conn(self) -> MySQLConnectionTypes: "'mysql-connector-python'. Warning! It might cause dependency conflicts." ) - conn_config = self._get_conn_config_mysql_connector_python(conn) # type: ignore[arg-type] + conn_config = self._get_conn_config_mysql_connector_python(conn) return mysql.connector.connect(**conn_config) raise ValueError("Unknown MySQL client name provided!") @@ -253,7 +258,7 @@ def bulk_load(self, table: str, tmp_file: str) -> None: (tmp_file,), ) conn.commit() - conn.close() # type: ignore[misc] + conn.close() def bulk_dump(self, table: str, tmp_file: str) -> None: """Dump a database table into a tab-delimited file.""" @@ -270,7 +275,7 @@ def bulk_dump(self, table: str, tmp_file: str) -> None: (tmp_file,), ) conn.commit() - conn.close() # type: ignore[misc] + conn.close() @staticmethod def _serialize_cell(cell: object, conn: Connection | None = None) -> Any: @@ -337,7 +342,7 @@ def bulk_load_custom( cursor.close() conn.commit() - conn.close() # type: ignore[misc] + conn.close() def get_openlineage_database_info(self, connection): """Return MySQL specific information for OpenLineage."""