diff --git a/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py b/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py index 8cfdbf753c27b..88799202654a5 100644 --- a/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py +++ b/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py @@ -39,6 +39,7 @@ class SparkKubernetesSensor(BaseSensorOperator): :param application_name: spark Application resource name :param namespace: the kubernetes namespace where the sparkApplication reside in + :param container_name: the kubernetes container name where the sparkApplication reside in :param kubernetes_conn_id: The :ref:`kubernetes connection` to Kubernetes cluster. :param attach_log: determines whether logs for driver pod should be appended to the sensor log @@ -56,6 +57,7 @@ def __init__( application_name: str, attach_log: bool = False, namespace: str | None = None, + container_name: str = "spark-kubernetes-driver", kubernetes_conn_id: str = "kubernetes_default", api_group: str = 'sparkoperator.k8s.io', api_version: str = 'v1beta2', @@ -65,6 +67,7 @@ def __init__( self.application_name = application_name self.attach_log = attach_log self.namespace = namespace + self.container_name = container_name self.kubernetes_conn_id = kubernetes_conn_id self.hook = KubernetesHook(conn_id=self.kubernetes_conn_id) self.api_group = api_group @@ -84,7 +87,9 @@ def _log_driver(self, application_state: str, response: dict) -> None: log_method = self.log.error if application_state in self.FAILURE_STATES else self.log.info try: log = "" - for line in self.hook.get_pod_logs(driver_pod_name, namespace=namespace): + for line in self.hook.get_pod_logs( + driver_pod_name, namespace=namespace, container=self.container_name + ): log += line.decode() log_method(log) except client.rest.ApiException as e: diff --git a/tests/providers/cncf/kubernetes/sensors/test_spark_kubernetes.py b/tests/providers/cncf/kubernetes/sensors/test_spark_kubernetes.py index c8f6764cac3fb..4f69e73a58d87 100644 --- a/tests/providers/cncf/kubernetes/sensors/test_spark_kubernetes.py +++ b/tests/providers/cncf/kubernetes/sensors/test_spark_kubernetes.py @@ -483,6 +483,68 @@ }, } +TEST_DRIVER_WITH_SIDECAR_APPLICATION = { + "apiVersion": "sparkoperator.k8s.io/v1beta2", + "kind": "SparkApplication", + "metadata": { + "creationTimestamp": "2020-02-24T07:34:22Z", + "generation": 1, + "labels": {"spark_flow_name": "spark-pi"}, + "name": "spark-pi-2020-02-24-1", + "namespace": "default", + "resourceVersion": "455577", + "selfLink": "/apis/sparkoperator.k8s.io/v1beta2/namespaces/default/sparkapplications/spark-pi", + "uid": "9f825516-6e1a-4af1-8967-b05661e8fb08", + }, + "spec": { + "driver": { + "coreLimit": "1200m", + "cores": 1, + "labels": {"spark_flow_name": "spark-pi", "version": "2.4.4"}, + "memory": "512m", + "serviceAccount": "default", + "volumeMounts": [{"mountPath": "/tmp", "name": "test-volume"}], + "sidecars": [{"name": "sidecar1", "image": "hello-world:latest"}], + }, + "executor": { + "cores": 1, + "instances": 3, + "labels": {"spark_flow_name": "spark-pi", "version": "2.4.4"}, + "memory": "512m", + "volumeMounts": [{"mountPath": "/tmp", "name": "test-volume"}], + }, + "image": "gcr.io/spark-operator/spark:v2.4.4", + "imagePullPolicy": "Always", + "mainApplicationFile": "local:///opt/spark/examples/jars/spark-examples_2.11-2.4.4.jar", + "mainClass": "org.apache.spark.examples.SparkPi", + "mode": "cluster", + "restartPolicy": {"type": "Never"}, + "sparkVersion": "2.4.4", + "type": "Scala", + "volumes": [{"hostPath": {"path": "/tmp", "type": "Directory"}, "name": "test-volume"}], + }, + "status": { + "applicationState": {"state": "COMPLETED"}, + "driverInfo": { + "podName": "spark-pi-2020-02-24-1-driver", + "webUIAddress": "10.97.130.44:4040", + "webUIPort": 4040, + "webUIServiceName": "spark-pi-2020-02-24-1-ui-svc", + }, + "executionAttempts": 1, + "executorState": { + "spark-pi-2020-02-24-1-1582529666227-exec-1": "FAILED", + "spark-pi-2020-02-24-1-1582529666227-exec-2": "FAILED", + "spark-pi-2020-02-24-1-1582529666227-exec-3": "FAILED", + }, + "lastSubmissionAttemptTime": "2020-02-24T07:34:30Z", + "sparkApplicationId": "spark-7bb432c422ca46f3854838c419460fec", + "submissionAttempts": 1, + "submissionID": "1a1f9c5e-6bdd-4824-806f-40a814c1cf43", + "terminationTime": "2020-02-24T07:35:01Z", + }, +} + TEST_POD_LOGS = [b"LOG LINE 1\n", b"LOG LINE 2"] TEST_POD_LOG_RESULT = "LOG LINE 1\nLOG LINE 2" @@ -726,7 +788,9 @@ def test_driver_logging_failure( ) with pytest.raises(AirflowException): sensor.poke(None) - mock_log_call.assert_called_once_with("spark-pi-driver", namespace="default") + mock_log_call.assert_called_once_with( + "spark-pi-driver", namespace="default", container='spark-kubernetes-driver' + ) error_log_call.assert_called_once_with(TEST_POD_LOG_RESULT) @patch( @@ -748,7 +812,9 @@ def test_driver_logging_completed( task_id="test_task_id", ) sensor.poke(None) - mock_log_call.assert_called_once_with("spark-pi-2020-02-24-1-driver", namespace="default") + mock_log_call.assert_called_once_with( + "spark-pi-2020-02-24-1-driver", namespace="default", container='spark-kubernetes-driver' + ) log_info_call = info_log_call.mock_calls[2] log_value = log_info_call[1][0] assert log_value == TEST_POD_LOG_RESULT @@ -773,3 +839,29 @@ def test_driver_logging_error( ) sensor.poke(None) warn_log_call.assert_called_once() + + @patch( + "kubernetes.client.api.custom_objects_api.CustomObjectsApi.get_namespaced_custom_object", + return_value=TEST_DRIVER_WITH_SIDECAR_APPLICATION, + ) + @patch("logging.Logger.info") + @patch( + "airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook.get_pod_logs", + return_value=TEST_POD_LOGS, + ) + def test_sidecar_driver_logging_completed( + self, mock_log_call, info_log_call, mock_get_namespaced_crd, mock_kube_conn + ): + sensor = SparkKubernetesSensor( + application_name="spark_pi", + attach_log=True, + dag=self.dag, + task_id="test_task_id", + ) + sensor.poke(None) + mock_log_call.assert_called_once_with( + "spark-pi-2020-02-24-1-driver", namespace="default", container='spark-kubernetes-driver' + ) + log_info_call = info_log_call.mock_calls[2] + log_value = log_info_call[1][0] + assert log_value == TEST_POD_LOG_RESULT