diff --git a/airflow/providers/apache/hdfs/hooks/hdfs.py b/airflow/providers/apache/hdfs/hooks/hdfs.py index 0e98320cb77d7..fda716327a022 100644 --- a/airflow/providers/apache/hdfs/hooks/hdfs.py +++ b/airflow/providers/apache/hdfs/hooks/hdfs.py @@ -51,7 +51,10 @@ class HDFSHook(BaseHook): hook_name = "HDFS" def __init__( - self, hdfs_conn_id: str = "hdfs_default", proxy_user: str | None = None, autoconfig: bool = False + self, + hdfs_conn_id: str | set[str] = "hdfs_default", + proxy_user: str | None = None, + autoconfig: bool = False, ): super().__init__() if not snakebite_loaded: @@ -60,7 +63,7 @@ def __init__( "snakebite is not compatible with Python 3 " "(as of August 2015). Please help by submitting a PR!" ) - self.hdfs_conn_id = hdfs_conn_id + self.hdfs_conn_id = {hdfs_conn_id} if isinstance(hdfs_conn_id, str) else hdfs_conn_id self.proxy_user = proxy_user self.autoconfig = autoconfig @@ -73,7 +76,7 @@ def get_conn(self) -> Any: use_sasl = conf.get("core", "security") == "kerberos" try: - connections = self.get_connections(self.hdfs_conn_id) + connections = [self.get_connection(i) for i in self.hdfs_conn_id] if not effective_user: effective_user = connections[0].login diff --git a/tests/providers/apache/hdfs/hooks/test_hdfs.py b/tests/providers/apache/hdfs/hooks/test_hdfs.py index 3b4f7e6d11225..bf724737a9604 100644 --- a/tests/providers/apache/hdfs/hooks/test_hdfs.py +++ b/tests/providers/apache/hdfs/hooks/test_hdfs.py @@ -74,10 +74,13 @@ def test_get_autoconfig_client_no_conn(self, mock_client): HDFSHook(hdfs_conn_id="hdfs_missing", autoconfig=True).get_conn() mock_client.assert_called_once_with(effective_user=None, use_sasl=False) - @mock.patch("airflow.providers.apache.hdfs.hooks.hdfs.HDFSHook.get_connections") - def test_get_ha_client(self, mock_get_connections): - conn_1 = Connection(conn_id="hdfs_default", conn_type="hdfs", host="localhost", port=8020) - conn_2 = Connection(conn_id="hdfs_default", conn_type="hdfs", host="localhost2", port=8020) - mock_get_connections.return_value = [conn_1, conn_2] - client = HDFSHook().get_conn() + @mock.patch.dict( + "os.environ", + { + "AIRFLOW_CONN_HDFS1": "hdfs://host1:8020", + "AIRFLOW_CONN_HDFS2": "hdfs://host2:8020", + }, + ) + def test_get_ha_client(self): + client = HDFSHook(hdfs_conn_id={"hdfs1", "hdfs2"}).get_conn() assert isinstance(client, snakebite.client.HAClient)