From a047627408c6008c615b486299458453af524ee9 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Fri, 29 May 2026 18:53:19 +0530 Subject: [PATCH 1/6] Use K8s API to track Spark on K8s instead of JVM based spark-submit --- .../apache/spark/hooks/spark_submit.py | 109 +++++++++++++++++- .../apache/spark/operators/spark_submit.py | 16 +++ .../apache/spark/hooks/test_spark_submit.py | 100 +++++++++++++++- .../spark/operators/test_spark_submit.py | 41 ++++++- 4 files changed, 259 insertions(+), 7 deletions(-) diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py index 9aa3ddc885efb..84b55148494df 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py @@ -41,6 +41,9 @@ DEFAULT_SPARK_BINARY = "spark-submit" ALLOWED_SPARK_BINARIES = [DEFAULT_SPARK_BINARY, "spark2-submit", "spark3-submit"] +_K8S_WAIT_APP_COMPLETION_CONF = "spark.kubernetes.submission.waitAppCompletion" +_K8S_DELETE_ON_TERMINATION_CONF = "spark.kubernetes.driver.deleteOnTermination" + class SparkSubmitHook(BaseHook, LoggingMixin): """ @@ -79,7 +82,10 @@ class SparkSubmitHook(BaseHook, LoggingMixin): :param name: Name of the job (default airflow-spark) :param num_executors: Number of executors to launch :param status_poll_interval: Seconds to wait between polls of driver status in cluster - mode (Default: 1) + mode. Used by the Spark standalone driver-status tracker and (when + ``track_driver_via_k8s_api=True``) by the Kubernetes API polling loop. + The Kubernetes polling loop enforces a 20-second minimum to avoid + excessive API server load on long-running jobs. (Default: 1) :param application_args: Arguments for the application being submitted :param env_vars: Environment variables for spark-submit. It supports yarn and k8s mode too. 
@@ -99,6 +105,13 @@ class SparkSubmitHook(BaseHook, LoggingMixin): job finishes (on both success and on_kill). Useful for cleaning up sidecars such as Istio (e.g. ``["curl -X POST localhost:15020/quitquitquit"]``). Each command is executed via the shell; failures produce a warning but do not fail the task. + :param track_driver_via_k8s_api: If True (when master is Kubernetes and + ``deploy_mode`` is ``cluster``), release the ``spark-submit`` JVM once the + driver pod has been created, then poll the Kubernetes API for the pod phase + until the application reaches a terminal state. The polling interval is + controlled by ``status_poll_interval`` with a 20-second minimum. This frees + the worker from holding the long-lived submit JVM (~500 MB). Defaults to + ``False``. """ conn_name_attr = "conn_id" @@ -207,6 +220,7 @@ def __init__( *, use_krb5ccache: bool = False, post_submit_commands: list[str] | None = None, + track_driver_via_k8s_api: bool = False, ) -> None: super().__init__() self._conf = conf or {} @@ -250,6 +264,7 @@ def __init__( f"{self._connection['master']} specified by kubernetes dependencies are not installed!" ) + self._track_driver_via_k8s_api = track_driver_via_k8s_api self._should_track_driver_status = self._resolve_should_track_driver_status() self._driver_id: str | None = None self._driver_status: str | None = None @@ -268,6 +283,29 @@ def _resolve_should_track_driver_status(self) -> bool: """ return "spark://" in self._connection["master"] and self._connection["deploy_mode"] == "cluster" + def _should_track_driver_via_k8s_api(self) -> bool: + return ( + self._track_driver_via_k8s_api + and self._is_kubernetes + and self._connection["deploy_mode"] == "cluster" + ) + + def _validate_track_driver_via_k8s_api_config(self) -> None: + if not self._is_kubernetes: + raise ValueError( + "`track_driver_via_k8s_api=True` requires Spark master to be Kubernetes (k8s://...)." 
+ ) + if self._connection["deploy_mode"] != "cluster": + raise ValueError( + "`track_driver_via_k8s_api=True` requires `deploy_mode='cluster'`; " + f"got deploy_mode={self._connection['deploy_mode']!r}." + ) + if self._conf.get(_K8S_WAIT_APP_COMPLETION_CONF, "").lower() == "true": + raise ValueError( + f"`track_driver_via_k8s_api=True` is incompatible with " + f"`{_K8S_WAIT_APP_COMPLETION_CONF}=true`; remove it from your conf or set it to 'false'." + ) + def _resolve_connection(self) -> dict[str, Any]: # Build from connection master or default to yarn if not available conn_data: dict[str, Any] = { @@ -495,6 +533,10 @@ def _build_spark_common_args(self) -> list[str]: if self._connection["deploy_mode"]: args += ["--deploy-mode", self._connection["deploy_mode"]] + if self._should_track_driver_via_k8s_api(): + if _K8S_WAIT_APP_COMPLETION_CONF not in self._conf: + args += ["--conf", f"{_K8S_WAIT_APP_COMPLETION_CONF}=false"] + return args def _build_spark_submit_command(self, application: str) -> list[str]: @@ -632,8 +674,14 @@ def submit(self, application: str = "", **kwargs: Any) -> str | None: # Check spark-submit return code. In Kubernetes mode, also check the value # of exit code in the log, as it may differ. + # When polling via K8s API, spark-submit exits after pod creation (waitAppCompletion=false) + # so _spark_exit_code is never set by the JVM watcher — skip that check entirely. try: - if returncode or (self._is_kubernetes and self._spark_exit_code != 0): + if returncode or ( + self._is_kubernetes + and not self._should_track_driver_via_k8s_api() + and self._spark_exit_code != 0 + ): if self._is_kubernetes: raise AirflowException( f"Cannot execute: {self._mask_cmd(spark_submit_cmd)}. Error code is: {returncode}. 
" @@ -682,10 +730,17 @@ def _process_spark_submit_log(self, itr: Iterator[Any]) -> None: # If we run Kubernetes cluster mode, we want to extract the driver pod id # from the logs so we can kill the application when we stop it unexpectedly elif self._is_kubernetes: + # Two log formats exist across Spark versions: + # "pod name: -driver" and "submission ID spark:-driver" match_driver_pod = re.search(r"\s*pod name: ((.+?)-([a-z0-9]+)-driver$)", line) if match_driver_pod: self._kubernetes_driver_pod = match_driver_pod.group(1) self.log.info("Identified spark driver pod: %s", self._kubernetes_driver_pod) + if not self._kubernetes_driver_pod: + match_submission_id = re.search(r"submission ID spark:(.+-driver)", line) + if match_submission_id: + self._kubernetes_driver_pod = match_submission_id.group(1) + self.log.info("Identified spark driver pod: %s", self._kubernetes_driver_pod) match_application_id = re.search(r"\s*spark-app-selector -> (spark-([a-z0-9]+)), ", line) if match_application_id: @@ -802,6 +857,54 @@ def _start_driver_status_tracking(self) -> None: f"returncode = {returncode}" ) + def _poll_k8s_driver_via_api(self) -> None: + """Poll the K8s driver pod phase until it reaches a terminal state.""" + pod_name = self._kubernetes_driver_pod + namespace = self._connection["namespace"] + app_id = self._kubernetes_application_id or pod_name + + if not pod_name: + raise ValueError("K8s driver pod name not set; cannot poll status.") + + client = kube_client.get_kube_client(in_cluster=False) + poll_interval = max(self._status_poll_interval, 20) + # similar to `missed_job_status_reports` tolerance in `_start_driver_status_tracking`: + # tolerate transient `Unknown` phases (node temporarily unreachable) before giving up. 
+ consecutive_unknown = 0 + max_consecutive_unknown = 3 + + while True: + pod = client.read_namespaced_pod(pod_name, namespace) + phase = pod.status.phase or "Initializing" + self.log.info("Application status for %s (phase: %s)", app_id, phase) + if phase == "Succeeded": + break + if phase == "Failed": + container_state = "" + if pod.status.container_statuses: + cs = pod.status.container_statuses[0] + if cs.state and cs.state.terminated: + container_state = ( + f" exit_code={cs.state.terminated.exit_code} reason={cs.state.terminated.reason}" + ) + raise RuntimeError(f"Spark application {app_id} failed (phase=Failed{container_state})") + if phase == "Unknown": + consecutive_unknown += 1 + if consecutive_unknown >= max_consecutive_unknown: + raise RuntimeError( + f"Spark application {app_id} reported Unknown phase " + f"{consecutive_unknown} times consecutively; giving up." + ) + else: + consecutive_unknown = 0 + time.sleep(poll_interval) + self._run_post_submit_commands() + try: + client.delete_namespaced_pod(pod_name, namespace) + self.log.info("Deleted driver pod %s", pod_name) + except kube_client.ApiException: + self.log.warning("Could not delete driver pod %s after completion", pod_name) + def _build_spark_driver_kill_command(self) -> list[str]: """ Construct the spark-submit command to kill a driver. 
@@ -865,7 +968,7 @@ def on_kill(self) -> None: try: import kubernetes - client = kube_client.get_kube_client() + client = kube_client.get_kube_client(in_cluster=False) api_response = client.delete_namespaced_pod( self._kubernetes_driver_pod, self._connection["namespace"], diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py index 76b010107da1d..cf295c3685599 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py @@ -108,6 +108,12 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator): on keytab for Kerberos login :param post_submit_commands: Optional list of shell commands to run after the Spark job finishes. Useful for cleaning up sidecars such as Istio. Failures produce a warning but do not fail the task. + :param track_driver_via_k8s_api: If True (when master is Kubernetes and ``deploy_mode`` + is ``cluster``), release the ``spark-submit`` JVM once the driver pod has been + created, then poll the Kubernetes API for the pod phase until the application + reaches a terminal state. The polling interval is controlled by + ``status_poll_interval`` with a 20-second minimum. This frees the worker from + holding the long-lived submit JVM. Defaults to ``False``. 
""" # Generic key used across all Spark deployment modes (standalone driver ID, @@ -168,6 +174,7 @@ def __init__( use_krb5ccache: bool = False, post_submit_commands: list[str] | None = None, reconnect_on_retry: bool = True, + track_driver_via_k8s_api: bool = False, openlineage_inject_parent_job_info: bool = conf.getboolean( "openlineage", "spark_inject_parent_job_info", fallback=False ), @@ -212,6 +219,7 @@ def __init__( self._use_krb5ccache = use_krb5ccache self.reconnect_on_retry = reconnect_on_retry + self._track_driver_via_k8s_api = track_driver_via_k8s_api self._openlineage_inject_parent_job_info = openlineage_inject_parent_job_info self._openlineage_inject_transport_info = openlineage_inject_transport_info @@ -234,6 +242,13 @@ def execute(self, context: Context) -> None: driver_id = self.submit_job(context) self.poll_until_complete(driver_id, context) return self.get_job_result(driver_id, context) + if hook._should_track_driver_via_k8s_api(): + hook._validate_track_driver_via_k8s_api_config() + # TODO: Wire into execute_resumable() via ResumableJobMixin + # (fill submit_job / poll_until_complete K8s stubs) to enable crash recovery. 
+ hook.submit(self.application) + hook._poll_k8s_driver_via_api() + return hook.submit(self.application) def submit_job(self, context: Context) -> str: @@ -378,4 +393,5 @@ def _get_hook(self) -> SparkSubmitHook: deploy_mode=self._deploy_mode, use_krb5ccache=self._use_krb5ccache, post_submit_commands=self.post_submit_commands, + track_driver_via_k8s_api=self._track_driver_via_k8s_api, ) diff --git a/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py index c909e9f12ab55..830b6bf18bbb6 100644 --- a/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py +++ b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py @@ -21,9 +21,10 @@ import os from io import StringIO from pathlib import Path -from unittest.mock import call, mock_open, patch +from unittest.mock import MagicMock, call, mock_open, patch import pytest +from kubernetes.client import V1Pod, V1PodStatus from airflow.models import Connection from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook @@ -905,6 +906,18 @@ def test_process_spark_submit_log_k8s_spark_3(self): # Then assert hook._spark_exit_code == 999 + def test_process_spark_submit_log_k8s_submission_id_format(self): + hook = SparkSubmitHook(conn_id="spark_k8s_cluster") + log_lines = [ + "INFO Client: Deployed Spark application arrow-spark with application ID " + "spark-1e22d65826b74ac2927249b0e607ed54 and submission ID " + "spark:arrow-spark-c8e2e29e73db9c93-driver into Kubernetes", + ] + + hook._process_spark_submit_log(log_lines) + + assert hook._kubernetes_driver_pod == "arrow-spark-c8e2e29e73db9c93-driver" + def test_process_spark_client_mode_submit_log_k8s(self): # Given hook = SparkSubmitHook(conn_id="spark_k8s_client") @@ -1302,7 +1315,6 @@ def test_create_keytab_path_from_base64_keytab_with_existing_keytab( def test_run_post_submit_commands_success(self, mock_run): """Test that 
post_submit_commands are run with shell=False and shlex.split.""" import subprocess - from unittest.mock import MagicMock mock_result = MagicMock(spec=subprocess.CompletedProcess) mock_result.returncode = 0 @@ -1330,7 +1342,6 @@ def test_run_post_submit_commands_success(self, mock_run): def test_run_post_submit_commands_nonzero_exit_warns(self, mock_run): """Test that a non-zero exit code logs a warning but does not raise.""" import subprocess - from unittest.mock import MagicMock mock_result = MagicMock(spec=subprocess.CompletedProcess) mock_result.returncode = 1 @@ -1366,3 +1377,86 @@ def test_post_submit_commands_none_gives_empty_list(self): """Test that None post_submit_commands results in an empty list.""" hook = SparkSubmitHook(conn_id="") assert hook._post_submit_commands == [] + + @pytest.mark.parametrize( + ("conn_id", "flag", "expected"), + [ + ("spark_k8s_cluster", False, False), + ("spark_k8s_cluster", True, True), + ("spark_k8s_client", True, False), + ], + ) + def test_should_track_driver_via_k8s_api(self, conn_id, flag, expected): + hook = SparkSubmitHook(conn_id=conn_id, track_driver_via_k8s_api=flag) + assert hook._should_track_driver_via_k8s_api() is expected + + @pytest.mark.parametrize( + ("conn_id", "match"), + [ + ("spark_yarn_cluster", "requires Spark master to be Kubernetes"), + ("spark_k8s_client", "requires `deploy_mode='cluster'`"), + ], + ) + def test_validate_track_driver_via_k8s_api_raises(self, conn_id, match): + hook = SparkSubmitHook(conn_id=conn_id, track_driver_via_k8s_api=True) + with pytest.raises(ValueError, match=match): + hook._validate_track_driver_via_k8s_api_config() + + def test_validate_track_driver_via_k8s_api_raises_on_conflicting_user_conf(self): + hook = SparkSubmitHook( + conn_id="spark_k8s_cluster", + track_driver_via_k8s_api=True, + conf={"spark.kubernetes.submission.waitAppCompletion": "true"}, + ) + with pytest.raises(ValueError, match="incompatible with.*waitAppCompletion=true"): + 
hook._validate_track_driver_via_k8s_api_config() + + def test_conf_injection_adds_wait_app_completion(self): + hook = SparkSubmitHook(conn_id="spark_k8s_cluster", track_driver_via_k8s_api=True) + cmd = hook._build_spark_submit_command("app.jar") + conf_pairs = [cmd[i + 1] for i, v in enumerate(cmd) if v == "--conf"] + assert "spark.kubernetes.submission.waitAppCompletion=false" in conf_pairs + + @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client") + def test_poll_k8s_driver_succeeds(self, mock_get_client): + hook = SparkSubmitHook(conn_id="spark_k8s_cluster", track_driver_via_k8s_api=True) + hook._kubernetes_driver_pod = "spark-app-abc-driver" + hook._kubernetes_application_id = "spark-abc" + + mock_client = mock_get_client.return_value + running_pod = V1Pod(status=V1PodStatus(phase="Running")) + succeeded_pod = V1Pod(status=V1PodStatus(phase="Succeeded")) + mock_client.read_namespaced_pod.side_effect = [running_pod, succeeded_pod] + + with patch.object(hook, "_run_post_submit_commands"): + hook._poll_k8s_driver_via_api() + + mock_client.delete_namespaced_pod.assert_called_once_with("spark-app-abc-driver", "mynamespace") + + @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client") + def test_poll_k8s_driver_raises_on_failed(self, mock_get_client): + hook = SparkSubmitHook(conn_id="spark_k8s_cluster", track_driver_via_k8s_api=True) + hook._kubernetes_driver_pod = "spark-app-abc-driver" + hook._kubernetes_application_id = "spark-abc" + + mock_client = mock_get_client.return_value + failed_pod = V1Pod(status=V1PodStatus(phase="Failed")) + mock_client.read_namespaced_pod.return_value = failed_pod + + with pytest.raises(RuntimeError, match="phase=Failed"): + hook._poll_k8s_driver_via_api() + + @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client") + def test_poll_k8s_driver_raises_after_consecutive_unknown(self, mock_get_client): + hook = SparkSubmitHook(conn_id="spark_k8s_cluster", track_driver_via_k8s_api=True) + 
hook._kubernetes_driver_pod = "spark-app-abc-driver" + hook._kubernetes_application_id = "spark-abc" + + mock_client = mock_get_client.return_value + mock_client.read_namespaced_pod.return_value = V1Pod(status=V1PodStatus(phase="Unknown")) + + with patch("time.sleep"), pytest.raises(RuntimeError, match="Unknown phase"): + hook._poll_k8s_driver_via_api() + + # assert that it was polled minimum 3 times to confirm the Unknown status before raising + assert mock_client.read_namespaced_pod.call_count == 3 diff --git a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py index 65af1116861bf..1850d5c53ad66 100644 --- a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py +++ b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py @@ -733,4 +733,43 @@ def simulate_failed_tracking(): with pytest.raises(RuntimeError, match="FAILED"): operator.poll_until_complete("driver-001", {}) - assert post_submit_called, "_run_post_submit_commands must be called even on driver failure" + +class TestSparkSubmitOperatorK8sTracking: + def setup_method(self): + args = {"owner": "airflow", "start_date": DEFAULT_DATE} + self.dag = DAG("test_k8s_tracking_dag", schedule=None, default_args=args) + + def _make_operator(self, **kwargs): + return SparkSubmitOperator(task_id="test", dag=self.dag, application="test.jar", **kwargs) + + def _make_k8s_hook(self): + hook = MagicMock() + hook._should_track_driver_status = False + hook._should_track_driver_via_k8s_api.return_value = True + return hook + + def test_execute_calls_submit_then_poll_when_flag_set(self): + operator = self._make_operator(track_driver_via_k8s_api=True) + hook = self._make_k8s_hook() + operator._hook = hook + call_order = [] + hook.submit.side_effect = lambda *a, **kw: call_order.append("submit") + hook._poll_k8s_driver_via_api.side_effect = lambda: call_order.append("poll") + + 
operator.execute(context={}) + + hook.submit.assert_called_once_with("test.jar") + hook._poll_k8s_driver_via_api.assert_called_once() + assert call_order == ["submit", "poll"] + + def test_execute_falls_through_to_plain_submit_when_flag_off(self): + operator = self._make_operator(track_driver_via_k8s_api=False) + hook = MagicMock() + hook._should_track_driver_status = False + hook._should_track_driver_via_k8s_api.return_value = False + operator._hook = hook + + operator.execute(context={}) + + hook.submit.assert_called_once_with("test.jar") + hook._poll_k8s_driver_via_api.assert_not_called() From 1127e10bc6a2cfcb83506084fa55f62b9e49ffe8 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Fri, 29 May 2026 19:01:28 +0530 Subject: [PATCH 2/6] Use K8s API to track Spark on K8s instead of JVM based spark-submit --- providers/apache/spark/docs/operators.rst | 29 +++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/providers/apache/spark/docs/operators.rst b/providers/apache/spark/docs/operators.rst index 64af53454f42d..916c8d75d0df8 100644 --- a/providers/apache/spark/docs/operators.rst +++ b/providers/apache/spark/docs/operators.rst @@ -214,3 +214,32 @@ See :doc:`connections/spark-submit` for how to configure these fields. .. note:: Crash recovery in cluster mode requires Airflow 3.3+ (``task_state`` support). On earlier versions the operator falls back to the previous behavior of always submitting fresh. + +Tracking driver status via Kubernetes API +"""""""""""""""""""""""""""""""""""""""""" + +When running in Kubernetes cluster mode, ``spark-submit`` blocks for the duration of the job. +The JVM process does nothing but poll the pod phase, yet it holds heap space for +the entire duration. This is not ideal for long-running jobs, especially when the driver is idle +for long periods (e.g. waiting for data or user input).
+ +Set ``track_driver_via_k8s_api=True`` to have the operator track the driver pod status via the +Python Kubernetes client rather than holding ``spark-submit`` open for the full job duration: + +.. code-block:: python + + from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator + + run_spark = SparkSubmitOperator( + task_id="run_spark", + application="local:///opt/spark/examples/jars/spark-examples.jar", + conn_id="spark_k8s", + deploy_mode="cluster", + track_driver_via_k8s_api=True, + ) + +**Requirements** + +* The Spark connection ``master`` must be ``k8s://...`` and ``deploy_mode`` must be ``cluster``. +* Do not set ``spark.kubernetes.submission.waitAppCompletion=true`` in your ``conf`` — this + conflicts with the flag and a ``ValueError`` will be raised at task start. From e42f28a08970451bfbd4f4444edb158e636588e4 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Sat, 30 May 2026 09:06:16 +0530 Subject: [PATCH 3/6] review comments from kaxil --- .../apache/spark/hooks/spark_submit.py | 117 ++++++++++++------ .../apache/spark/operators/spark_submit.py | 3 +- .../apache/spark/hooks/test_spark_submit.py | 43 +++++++ 3 files changed, 126 insertions(+), 37 deletions(-) diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py index 84b55148494df..25a0cc7d35fad 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py @@ -42,7 +42,6 @@ ALLOWED_SPARK_BINARIES = [DEFAULT_SPARK_BINARY, "spark2-submit", "spark3-submit"] _K8S_WAIT_APP_COMPLETION_CONF = "spark.kubernetes.submission.waitAppCompletion" -_K8S_DELETE_ON_TERMINATION_CONF = "spark.kubernetes.driver.deleteOnTermination" class SparkSubmitHook(BaseHook, LoggingMixin): @@ -300,6 +299,11 @@ def _validate_track_driver_via_k8s_api_config(self) -> None: 
"`track_driver_via_k8s_api=True` requires `deploy_mode='cluster'`; " f"got deploy_mode={self._connection['deploy_mode']!r}." ) + if not self._connection.get("namespace"): + raise ValueError( + "`track_driver_via_k8s_api=True` requires a namespace; " + "set it in the connection extra as `namespace` or via `spark.kubernetes.namespace` in conf." + ) if self._conf.get(_K8S_WAIT_APP_COMPLETION_CONF, "").lower() == "true": raise ValueError( f"`track_driver_via_k8s_api=True` is incompatible with " @@ -737,7 +741,7 @@ def _process_spark_submit_log(self, itr: Iterator[Any]) -> None: self._kubernetes_driver_pod = match_driver_pod.group(1) self.log.info("Identified spark driver pod: %s", self._kubernetes_driver_pod) if not self._kubernetes_driver_pod: - match_submission_id = re.search(r"submission ID spark:(.+-driver)", line) + match_submission_id = re.search(r"submission ID spark:(.+?-driver)", line) if match_submission_id: self._kubernetes_driver_pod = match_submission_id.group(1) self.log.info("Identified spark driver pod: %s", self._kubernetes_driver_pod) @@ -866,44 +870,85 @@ def _poll_k8s_driver_via_api(self) -> None: if not pod_name: raise ValueError("K8s driver pod name not set; cannot poll status.") - client = kube_client.get_kube_client(in_cluster=False) + client = kube_client.get_kube_client() poll_interval = max(self._status_poll_interval, 20) - # similar to `missed_job_status_reports` tolerance in `_start_driver_status_tracking`: - # tolerate transient `Unknown` phases (node temporarily unreachable) before giving up. + if poll_interval != self._status_poll_interval: + self.log.info( + "status_poll_interval=%ds is below the 20s minimum for K8s API polling; using 20s.", + self._status_poll_interval, + ) + # Mirror `missed_job_status_reports` / `max_missed_job_status_reports` from + # `_start_driver_status_tracking`: tolerate transient failures before giving up. 
consecutive_unknown = 0 max_consecutive_unknown = 3 + consecutive_api_errors = 0 + max_consecutive_api_errors = 3 + consecutive_pending = 0 + pending_warn_threshold = 10 - while True: - pod = client.read_namespaced_pod(pod_name, namespace) - phase = pod.status.phase or "Initializing" - self.log.info("Application status for %s (phase: %s)", app_id, phase) - if phase == "Succeeded": - break - if phase == "Failed": - container_state = "" - if pod.status.container_statuses: - cs = pod.status.container_statuses[0] - if cs.state and cs.state.terminated: - container_state = ( - f" exit_code={cs.state.terminated.exit_code} reason={cs.state.terminated.reason}" - ) - raise RuntimeError(f"Spark application {app_id} failed (phase=Failed{container_state})") - if phase == "Unknown": - consecutive_unknown += 1 - if consecutive_unknown >= max_consecutive_unknown: - raise RuntimeError( - f"Spark application {app_id} reported Unknown phase " - f"{consecutive_unknown} times consecutively; giving up." - ) - else: - consecutive_unknown = 0 - time.sleep(poll_interval) - self._run_post_submit_commands() try: - client.delete_namespaced_pod(pod_name, namespace) - self.log.info("Deleted driver pod %s", pod_name) - except kube_client.ApiException: - self.log.warning("Could not delete driver pod %s after completion", pod_name) + while True: + try: + pod = client.read_namespaced_pod(pod_name, namespace) + consecutive_api_errors = 0 + except kube_client.ApiException as e: + consecutive_api_errors += 1 + self.log.warning( + "ApiException polling pod %s (%d/%d): %s", + pod_name, + consecutive_api_errors, + max_consecutive_api_errors, + e, + ) + if consecutive_api_errors >= max_consecutive_api_errors: + raise RuntimeError( + f"K8s API unreachable after {consecutive_api_errors} consecutive errors " + f"while polling {app_id}; giving up." 
+ ) from e + time.sleep(poll_interval) + continue + + phase = pod.status.phase or "Initializing" + self.log.info("Application status for %s (phase: %s)", app_id, phase) + if phase == "Succeeded": + break + if phase == "Failed": + container_state = "" + if pod.status.container_statuses: + cs = pod.status.container_statuses[0] + if cs.state and cs.state.terminated: + container_state = f" exit_code={cs.state.terminated.exit_code} reason={cs.state.terminated.reason}" + raise RuntimeError(f"Spark application {app_id} failed (phase=Failed{container_state})") + if phase == "Pending": + consecutive_pending += 1 + if consecutive_pending == pending_warn_threshold: + self.log.warning( + "Driver pod %s has been Pending for %d polls (~%ds); " + "it may be unschedulable. Continuing to wait — set execution_timeout to bound wait time.", + pod_name, + consecutive_pending, + consecutive_pending * poll_interval, + ) + else: + consecutive_pending = 0 + + if phase == "Unknown": + consecutive_unknown += 1 + if consecutive_unknown >= max_consecutive_unknown: + raise RuntimeError( + f"Spark application {app_id} reported Unknown phase " + f"{consecutive_unknown} times consecutively; giving up." 
+ ) + else: + consecutive_unknown = 0 + time.sleep(poll_interval) + try: + client.delete_namespaced_pod(pod_name, namespace) + self.log.info("Deleted driver pod %s", pod_name) + except kube_client.ApiException: + self.log.warning("Could not delete driver pod %s after completion", pod_name) + finally: + self._run_post_submit_commands() def _build_spark_driver_kill_command(self) -> list[str]: """ @@ -968,7 +1013,7 @@ def on_kill(self) -> None: try: import kubernetes - client = kube_client.get_kube_client(in_cluster=False) + client = kube_client.get_kube_client() api_response = client.delete_namespaced_pod( self._kubernetes_driver_pod, self._connection["namespace"], diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py index cf295c3685599..6ae19ec4ff6d0 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py @@ -235,6 +235,8 @@ def execute(self, context: Context) -> None: if self._hook is None: self._hook = self._get_hook() hook = self._hook + if self._track_driver_via_k8s_api: + hook._validate_track_driver_via_k8s_api_config() if hook._should_track_driver_status: if self.reconnect_on_retry: return self.execute_resumable(context) @@ -243,7 +245,6 @@ def execute(self, context: Context) -> None: self.poll_until_complete(driver_id, context) return self.get_job_result(driver_id, context) if hook._should_track_driver_via_k8s_api(): - hook._validate_track_driver_via_k8s_api_config() # TODO: Wire into execute_resumable() via ResumableJobMixin # (fill submit_job / poll_until_complete K8s stubs) to enable crash recovery. 
hook.submit(self.application) diff --git a/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py index 830b6bf18bbb6..8befc4191bb8c 100644 --- a/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py +++ b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py @@ -28,6 +28,7 @@ from airflow.models import Connection from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook +from airflow.providers.cncf.kubernetes import kube_client from airflow.providers.common.compat.sdk import AirflowException @@ -100,6 +101,14 @@ def setup_connections(self, create_connection_without_db): extra='{"deploy-mode": "client", "namespace": "mynamespace"}', ) ) + create_connection_without_db( + Connection( + conn_id="spark_k8s_cluster_no_namespace", + conn_type="spark", + host="k8s://https://k8s-master", + extra='{"deploy-mode": "cluster"}', + ) + ) create_connection_without_db( Connection(conn_id="spark_default_mesos", conn_type="spark", host="mesos://host", port=5050) ) @@ -1395,6 +1404,7 @@ def test_should_track_driver_via_k8s_api(self, conn_id, flag, expected): [ ("spark_yarn_cluster", "requires Spark master to be Kubernetes"), ("spark_k8s_client", "requires `deploy_mode='cluster'`"), + ("spark_k8s_cluster_no_namespace", "requires a namespace"), ], ) def test_validate_track_driver_via_k8s_api_raises(self, conn_id, match): @@ -1460,3 +1470,36 @@ def test_poll_k8s_driver_raises_after_consecutive_unknown(self, mock_get_client) # assert that it was polled minimum 3 times to confirm the Unknown status before raising assert mock_client.read_namespaced_pod.call_count == 3 + + @patch("time.sleep") + @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client") + def test_poll_k8s_driver_tolerates_transient_api_errors(self, mock_get_client, _): + hook = SparkSubmitHook(conn_id="spark_k8s_cluster", track_driver_via_k8s_api=True) + 
hook._kubernetes_driver_pod = "spark-app-abc-driver" + hook._kubernetes_application_id = "spark-abc" + + mock_client = mock_get_client.return_value + api_error = kube_client.ApiException(status=500, reason="Internal Server Error") + succeeded_pod = V1Pod(status=V1PodStatus(phase="Succeeded")) + mock_client.read_namespaced_pod.side_effect = [api_error, api_error, succeeded_pod] + + with patch.object(hook, "_run_post_submit_commands"): + hook._poll_k8s_driver_via_api() + + assert mock_client.read_namespaced_pod.call_count == 3 + + @patch("time.sleep") + @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client") + def test_poll_k8s_driver_raises_after_consecutive_api_errors(self, mock_get_client, _): + hook = SparkSubmitHook(conn_id="spark_k8s_cluster", track_driver_via_k8s_api=True) + hook._kubernetes_driver_pod = "spark-app-abc-driver" + hook._kubernetes_application_id = "spark-abc" + + mock_client = mock_get_client.return_value + api_error = kube_client.ApiException(status=500, reason="Internal Server Error") + mock_client.read_namespaced_pod.side_effect = api_error + + with pytest.raises(RuntimeError, match="K8s API unreachable"): + hook._poll_k8s_driver_via_api() + + assert mock_client.read_namespaced_pod.call_count == 3 From d12200d5f33ea3174208900ca8011f7e4649bd6e Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Tue, 2 Jun 2026 10:31:08 +0530 Subject: [PATCH 4/6] comments from kaxil --- providers/apache/spark/docs/operators.rst | 4 ++ .../apache/spark/hooks/spark_submit.py | 54 +++++++++++-------- .../apache/spark/hooks/test_spark_submit.py | 39 ++++++++++++++ 3 files changed, 74 insertions(+), 23 deletions(-) diff --git a/providers/apache/spark/docs/operators.rst b/providers/apache/spark/docs/operators.rst index 916c8d75d0df8..70bc9e0fd181b 100644 --- a/providers/apache/spark/docs/operators.rst +++ b/providers/apache/spark/docs/operators.rst @@ -243,3 +243,7 @@ Python Kubernetes client rather than holding ``spark-submit`` open for the full * The 
Spark connection ``master`` must be ``k8s://...`` and ``deploy_mode`` must be ``cluster``. * Do not set ``spark.kubernetes.submission.waitAppCompletion=true`` in your ``conf`` — this conflicts with the flag and a ``ValueError`` will be raised at task start. +* The Airflow worker must be able to reach the Kubernetes API server and have permission to + read and delete pods in the driver's namespace; otherwise pod tracking and cleanup will fail. +* This path bypasses ``ResumableJobMixin``, so Airflow retries submit a fresh driver instead of + reconnecting to an existing one. Set ``execution_timeout`` to bound wall-clock time. diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py index 25a0cc7d35fad..bd2ba064d0d97 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py @@ -304,7 +304,7 @@ def _validate_track_driver_via_k8s_api_config(self) -> None: "`track_driver_via_k8s_api=True` requires a namespace; " "set it in the connection extra as `namespace` or via `spark.kubernetes.namespace` in conf." ) - if self._conf.get(_K8S_WAIT_APP_COMPLETION_CONF, "").lower() == "true": + if str(self._conf.get(_K8S_WAIT_APP_COMPLETION_CONF, "")).lower() == "true": raise ValueError( f"`track_driver_via_k8s_api=True` is incompatible with " f"`{_K8S_WAIT_APP_COMPLETION_CONF}=true`; remove it from your conf or set it to 'false'." @@ -700,10 +700,11 @@ def submit(self, application: str = "", **kwargs: Any) -> str | None: "No driver id is known: something went wrong when executing the spark submit command" ) finally: - # In cluster mode with driver tracking, the operator calls poll_until_complete - # after submit() returns, so post_submit_commands are deferred there to preserve - # the "runs after job finishes" contract. 
In all other modes, run them here. - if not self._should_track_driver_status: + # K8s-API tracking defers post-submit commands to _poll_k8s_driver_via_api's finally + # block so they run once after the driver reaches a terminal state. Spark cluster-mode + # driver tracking defers them to poll_until_complete for the same reason. All other + # modes run them here, immediately after spark-submit exits. + if not self._should_track_driver_status and not self._should_track_driver_via_k8s_api(): self._run_post_submit_commands() return self._driver_id @@ -970,6 +971,25 @@ def _build_spark_driver_kill_command(self) -> list[str]: return connection_cmd + def _delete_driver_pod(self) -> None: + """Delete the Kubernetes driver pod, logging a warning on failure.""" + import kubernetes + + self.log.info("Deleting driver pod %s on Kubernetes", self._kubernetes_driver_pod) + try: + client = kube_client.get_kube_client() + client.delete_namespaced_pod( + self._kubernetes_driver_pod, + self._connection["namespace"], + body=kubernetes.client.V1DeleteOptions(), + pretty=True, + ) + self.log.info("Deleted driver pod %s", self._kubernetes_driver_pod) + except kube_client.ApiException: + self.log.exception( + "Exception when attempting to delete driver pod %s", self._kubernetes_driver_pod + ) + def on_kill(self) -> None: """Kill Spark submit command.""" self.log.debug("Kill Command is being called") @@ -983,6 +1003,11 @@ def on_kill(self) -> None: "Spark driver %s killed with return code: %s", self._driver_id, driver_kill.wait() ) + if self._should_track_driver_via_k8s_api() and self._kubernetes_driver_pod: + # spark-submit exits early under waitAppCompletion=false, so _submit_sp.poll() is + # not None during the poll loop — the deletion block below is skipped on kill. 
+ self._delete_driver_pod() + if self._submit_sp and self._submit_sp.poll() is None: self.log.info("Sending kill signal to %s", self._connection["spark_binary"]) self._submit_sp.kill() @@ -1007,23 +1032,6 @@ def on_kill(self) -> None: self.log.info("YARN app killed with return code: %s", yarn_kill.wait()) if self._kubernetes_driver_pod: - self.log.info("Killing pod %s on Kubernetes", self._kubernetes_driver_pod) - - # Currently only instantiate Kubernetes client for killing a spark pod. - try: - import kubernetes - - client = kube_client.get_kube_client() - api_response = client.delete_namespaced_pod( - self._kubernetes_driver_pod, - self._connection["namespace"], - body=kubernetes.client.V1DeleteOptions(), - pretty=True, - ) - - self.log.info("Spark on K8s killed with response: %s", api_response) - - except kube_client.ApiException: - self.log.exception("Exception when attempting to kill Spark on K8s") + self._delete_driver_pod() self._run_post_submit_commands() diff --git a/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py index 8befc4191bb8c..96337a1aa6265 100644 --- a/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py +++ b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py @@ -23,6 +23,7 @@ from pathlib import Path from unittest.mock import MagicMock, call, mock_open, patch +import kubernetes import pytest from kubernetes.client import V1Pod, V1PodStatus @@ -1123,6 +1124,29 @@ def test_k8s_process_on_kill(self, mock_popen, mock_client_method): "spark-pi-edf2ace37be7353a958b38733a12f8e6-driver", "mynamespace", **kwargs ) + @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client") + def test_on_kill_deletes_pod_when_k8s_api_tracking_and_submit_sp_already_exited(self, mock_get_client): + """on_kill must delete the driver pod when K8s-API tracking is active even if spark-submit + has already exited. 
+ """ + hook = SparkSubmitHook(conn_id="spark_k8s_cluster", track_driver_via_k8s_api=True) + hook._kubernetes_driver_pod = "spark-app-abc-driver" + hook._kubernetes_application_id = "spark-abc" + hook._submit_sp = MagicMock() + # spark-submit already exited + hook._submit_sp.poll.return_value = 0 + + mock_client = mock_get_client.return_value + + hook.on_kill() + + mock_client.delete_namespaced_pod.assert_called_once_with( + "spark-app-abc-driver", + "mynamespace", + body=kubernetes.client.V1DeleteOptions(), + pretty=True, + ) + @pytest.mark.parametrize( ("command", "expected"), [ @@ -1488,6 +1512,21 @@ def test_poll_k8s_driver_tolerates_transient_api_errors(self, mock_get_client, _ assert mock_client.read_namespaced_pod.call_count == 3 + @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client") + def test_post_submit_commands_run_exactly_once_on_k8s_path(self, mock_get_client): + """_run_post_submit_commands must fire exactly once: in _poll_k8s_driver_via_api finally.""" + hook = SparkSubmitHook(conn_id="spark_k8s_cluster", track_driver_via_k8s_api=True) + hook._kubernetes_driver_pod = "spark-app-abc-driver" + hook._kubernetes_application_id = "spark-abc" + + mock_client = mock_get_client.return_value + mock_client.read_namespaced_pod.return_value = V1Pod(status=V1PodStatus(phase="Succeeded")) + + with patch.object(hook, "_run_post_submit_commands") as mock_cmd: + hook._poll_k8s_driver_via_api() + + mock_cmd.assert_called_once() + @patch("time.sleep") @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client") def test_poll_k8s_driver_raises_after_consecutive_api_errors(self, mock_get_client, _): From 60be3a5eed26212a59abd791c2126a6efd410c5e Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 3 Jun 2026 09:37:22 +0530 Subject: [PATCH 5/6] review comments from kaxil --- providers/apache/spark/docs/operators.rst | 4 +++ .../apache/spark/hooks/spark_submit.py | 23 ++++++++++------ .../apache/spark/hooks/test_spark_submit.py | 27 
++++++++++++++++++- 3 files changed, 45 insertions(+), 9 deletions(-) diff --git a/providers/apache/spark/docs/operators.rst b/providers/apache/spark/docs/operators.rst index 70bc9e0fd181b..249eb6e375e15 100644 --- a/providers/apache/spark/docs/operators.rst +++ b/providers/apache/spark/docs/operators.rst @@ -247,3 +247,7 @@ Python Kubernetes client rather than holding ``spark-submit`` open for the full read and delete pods in the driver's namespace; otherwise pod tracking and cleanup will fail. * This path bypasses ``ResumableJobMixin``, so Airflow retries submit a fresh driver instead of reconnecting to an existing one. Set ``execution_timeout`` to bound wall-clock time. +* Pod completion is detected from ``pod.status.phase``. If your driver pods have sidecar + containers (e.g. Istio injection enabled for the driver namespace), the pod phase may not + advance to ``Succeeded`` until all sidecars exit. In that case the poll loop will wait + indefinitely — set ``execution_timeout`` as a hard bound. diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py index bd2ba064d0d97..3d7850316bc8b 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py @@ -270,6 +270,7 @@ def __init__( self._spark_exit_code: int | None = None self._env: dict[str, Any] | None = None self._post_submit_commands: list[str] = list(post_submit_commands) if post_submit_commands else [] + self._post_submit_commands_done: bool = False def _resolve_should_track_driver_status(self) -> bool: """ @@ -625,7 +626,12 @@ def _run_post_submit_commands(self) -> None: Called after the Spark job finishes (success or on_kill). Typical use case is killing sidecars like Istio that don't shut down automatically. Failures are logged as warnings and never raise. 
+ Guaranteed to run at most once per hook instance even if called from both + the poll-loop finally and on_kill (e.g. after a SIGTERM). """ + if self._post_submit_commands_done: + return + self._post_submit_commands_done = True for cmd in self._post_submit_commands: self.log.debug("Running post-submit command: %s", cmd) try: @@ -868,9 +874,6 @@ def _poll_k8s_driver_via_api(self) -> None: namespace = self._connection["namespace"] app_id = self._kubernetes_application_id or pod_name - if not pod_name: - raise ValueError("K8s driver pod name not set; cannot poll status.") - client = kube_client.get_kube_client() poll_interval = max(self._status_poll_interval, 20) if poll_interval != self._status_poll_interval: @@ -888,11 +891,19 @@ def _poll_k8s_driver_via_api(self) -> None: pending_warn_threshold = 10 try: + if not pod_name: + raise ValueError("K8s driver pod name not set; cannot poll status.") while True: try: pod = client.read_namespaced_pod(pod_name, namespace) consecutive_api_errors = 0 except kube_client.ApiException as e: + if e.status == 404: + self.log.info( + "Driver pod %s not found (404); pod was likely deleted by on_kill. 
Exiting poll loop.", + pod_name, + ) + return consecutive_api_errors += 1 self.log.warning( "ApiException polling pod %s (%d/%d): %s", @@ -943,11 +954,7 @@ def _poll_k8s_driver_via_api(self) -> None: else: consecutive_unknown = 0 time.sleep(poll_interval) - try: - client.delete_namespaced_pod(pod_name, namespace) - self.log.info("Deleted driver pod %s", pod_name) - except kube_client.ApiException: - self.log.warning("Could not delete driver pod %s after completion", pod_name) + self._delete_driver_pod() finally: self._run_post_submit_commands() diff --git a/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py index 96337a1aa6265..97b39fe86a6f2 100644 --- a/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py +++ b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py @@ -1465,7 +1465,7 @@ def test_poll_k8s_driver_succeeds(self, mock_get_client): with patch.object(hook, "_run_post_submit_commands"): hook._poll_k8s_driver_via_api() - mock_client.delete_namespaced_pod.assert_called_once_with("spark-app-abc-driver", "mynamespace") + assert mock_client.delete_namespaced_pod.call_args.args[:2] == ("spark-app-abc-driver", "mynamespace") @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client") def test_poll_k8s_driver_raises_on_failed(self, mock_get_client): @@ -1542,3 +1542,28 @@ def test_poll_k8s_driver_raises_after_consecutive_api_errors(self, mock_get_clie hook._poll_k8s_driver_via_api() assert mock_client.read_namespaced_pod.call_count == 3 + + @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client") + def test_poll_k8s_driver_exits_cleanly_on_404(self, mock_get_client): + """404 from read_namespaced_pod means pod was deleted by on_kill — should return cleanly, not raise.""" + hook = SparkSubmitHook(conn_id="spark_k8s_cluster", track_driver_via_k8s_api=True) + hook._kubernetes_driver_pod = 
"spark-app-abc-driver" + hook._kubernetes_application_id = "spark-abc" + + mock_client = mock_get_client.return_value + mock_client.read_namespaced_pod.side_effect = kube_client.ApiException(status=404, reason="Not Found") + + hook._poll_k8s_driver_via_api() + + mock_client.delete_namespaced_pod.assert_not_called() + + @patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.run") + def test_run_post_submit_commands_runs_only_once(self, mock_run): + """Calling _run_post_submit_commands twice must execute commands exactly once.""" + mock_run.return_value = MagicMock(returncode=0, stdout="") + hook = SparkSubmitHook(conn_id="spark_k8s_cluster", post_submit_commands=["echo done"]) + + hook._run_post_submit_commands() + hook._run_post_submit_commands() + + mock_run.assert_called_once() From 34952fa0bd2b38c78a504ce43d4674629bd30628 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Thu, 4 Jun 2026 14:41:54 +0530 Subject: [PATCH 6/6] bad formatting --- .../src/airflow/providers/apache/spark/hooks/spark_submit.py | 3 --- .../airflow/providers/apache/spark/operators/spark_submit.py | 3 --- 2 files changed, 6 deletions(-) diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py index 14ff3e0830953..7cf1f3248adc8 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py @@ -116,7 +116,6 @@ class SparkSubmitHook(BaseHook, LoggingMixin): job finishes (on both success and on_kill). Useful for cleaning up sidecars such as Istio (e.g. ``["curl -X POST localhost:15020/quitquitquit"]``). Each command is executed via the shell; failures produce a warning but do not fail the task. 
- <<<<<<< HEAD :param track_driver_via_k8s_api: If True (when master is Kubernetes and ``deploy_mode`` is ``cluster``), release the ``spark-submit`` JVM once the driver pod has been created, then poll the Kubernetes API for the pod phase @@ -124,7 +123,6 @@ class SparkSubmitHook(BaseHook, LoggingMixin): controlled by ``status_poll_interval`` with a 20-second minimum. This frees the worker from holding the long-lived submit JVM (~500 MB). Defaults to ``False``. - ======= :param yarn_track_via_rm_api: If True (when master is YARN and ``deploy_mode`` is ``cluster``), release the ``spark-submit`` JVM once the application has been submitted to YARN, then poll the YARN ResourceManager REST API @@ -141,7 +139,6 @@ class SparkSubmitHook(BaseHook, LoggingMixin): ``keytab`` and ``principal`` configured use ``requests-kerberos`` automatically. Defaults to ``None`` (no auth for non-Kerberos connections). - >>>>>>> main """ conn_name_attr = "conn_id" diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py index ce88c3748b23c..5321dbbb8befc 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py @@ -113,14 +113,12 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator): on keytab for Kerberos login :param post_submit_commands: Optional list of shell commands to run after the Spark job finishes. Useful for cleaning up sidecars such as Istio. Failures produce a warning but do not fail the task. - <<<<<<< HEAD :param track_driver_via_k8s_api: If True (when master is Kubernetes and ``deploy_mode`` is ``cluster``), release the ``spark-submit`` JVM once the driver pod has been created, then poll the Kubernetes API for the pod phase until the application reaches a terminal state. 
The polling interval is controlled by ``status_poll_interval`` with a 20-second minimum. This frees the worker from holding the long-lived submit JVM. Defaults to ``False``. - ======= :param yarn_track_via_rm_api: If True (when master is YARN and ``deploy_mode`` is ``cluster``), release the ``spark-submit`` JVM once the application has been submitted to YARN, then poll the YARN ResourceManager REST API @@ -136,7 +134,6 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator): omitted, Kerberos-enabled Spark connections with both ``keytab`` and ``principal`` configured use ``requests-kerberos`` automatically. Defaults to ``None`` (no auth for non-Kerberos connections). - >>>>>>> main """ # Generic key used across all Spark deployment modes (standalone driver ID,