diff --git a/airflow/providers/cncf/kubernetes/utils/pod_manager.py b/airflow/providers/cncf/kubernetes/utils/pod_manager.py index c65150b4181d4..a1fd511761d6e 100644 --- a/airflow/providers/cncf/kubernetes/utils/pod_manager.py +++ b/airflow/providers/cncf/kubernetes/utils/pod_manager.py @@ -121,7 +121,16 @@ def __init__( :param cluster_context: context of the cluster """ super().__init__() - self._client = kube_client or get_kube_client(in_cluster=in_cluster, cluster_context=cluster_context) + if kube_client: + self._client = kube_client + else: + self._client = get_kube_client(in_cluster=in_cluster, cluster_context=cluster_context) + warnings.warn( + "`kube_client` not supplied to PodManager. " + "This will be a required argument in a future release. " + "Please use KubernetesHook to create the client before calling.", + DeprecationWarning, + ) self._watch = watch.Watch() def run_pod_async(self, pod: V1Pod, **kwargs) -> V1Pod: diff --git a/tests/providers/cncf/kubernetes/utils/test_pod_manager.py b/tests/providers/cncf/kubernetes/utils/test_pod_manager.py index 5299b6a4be778..9cb68b8897e5d 100644 --- a/tests/providers/cncf/kubernetes/utils/test_pod_manager.py +++ b/tests/providers/cncf/kubernetes/utils/test_pod_manager.py @@ -29,6 +29,7 @@ from airflow.exceptions import AirflowException from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager, PodPhase, container_is_running +from tests.test_utils.providers import get_provider_version, object_exists class TestPodManager: @@ -331,6 +332,20 @@ def test_fetch_container_running_follow( assert ret.last_log_time == DateTime(2021, 1, 1, tzinfo=Timezone('UTC')) assert ret.running is exp_running + def test_pod_manager_get_client_call_deprecation(self): + """Ensure that kube_client.get_kube_client is removed from pod manager in provider 6.0.""" + kube_client_path = 'airflow.providers.cncf.kubernetes.utils.pod_manager.get_kube_client' + if not object_exists(kube_client_path): + raise Exception( + "You must remove this test. It only exists to remind us to remove `get_kube_client`." + ) + + if get_provider_version('apache-airflow-providers-cncf-kubernetes') >= (6, 0): + raise Exception( + "You must now remove `get_kube_client` from PodManager " + "and make kube_client a required argument." + ) + def params_for_test_container_is_running(): """The `container_is_running` method is designed to handle an assortment of bad objects diff --git a/tests/test_utils/providers.py b/tests/test_utils/providers.py new file mode 100644 index 0000000000000..84a89bef81cd8 --- /dev/null +++ b/tests/test_utils/providers.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import semver + + +def object_exists(path: str): + """Returns true if importable python object is there.""" + from airflow.utils.module_loading import import_string + + try: + import_string(path) + return True + except ImportError: + return False + + +def get_provider_version(provider_name): + """ + Returns provider version given provider package name. + + Example:: + if provider_version('apache-airflow-providers-cncf-kubernetes') >= (6, 0): + raise Exception( + "You must now remove `get_kube_client` from PodManager " + "and make kube_client a required argument." + ) + """ + from airflow.providers_manager import ProvidersManager + + info = ProvidersManager().providers[provider_name] + return semver.VersionInfo.parse(info.version)