diff --git a/providers/apache/spark/docs/operators.rst b/providers/apache/spark/docs/operators.rst index f20d389811ed7..6bdd4bbcdc772 100644 --- a/providers/apache/spark/docs/operators.rst +++ b/providers/apache/spark/docs/operators.rst @@ -215,6 +215,43 @@ See :doc:`connections/spark-submit` for how to configure these fields. Crash recovery in cluster mode requires Airflow 3.3+ (``task_state`` support). On earlier versions the operator falls back to the previous behavior of always submitting fresh. +Tracking driver status via Kubernetes API +"""""""""""""""""""""""""""""""""""""""""" + +When running in Kubernetes cluster mode, ``spark-submit`` blocks for the duration of the job. +During that time its JVM does nothing but poll the driver pod phase, yet it holds heap space for +the entire duration. This is wasteful for long-running jobs, especially when the driver is idle +for long periods (e.g. waiting for data or user input). + +Set ``track_driver_via_k8s_api=True`` to have the operator track the driver pod status via the +Python Kubernetes client rather than holding ``spark-submit`` open for the full job duration: + +.. code-block:: python + + from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator + + run_spark = SparkSubmitOperator( + task_id="run_spark", + application="local:///opt/spark/examples/jars/spark-examples.jar", + conn_id="spark_k8s", + deploy_mode="cluster", + track_driver_via_k8s_api=True, + ) + +**Requirements** + +* The Spark connection ``master`` must be ``k8s://...`` and ``deploy_mode`` must be ``cluster``. +* Do not set ``spark.kubernetes.submission.waitAppCompletion=true`` in your ``conf`` — this + conflicts with the flag and a ``ValueError`` will be raised at task start. +* The Airflow worker must be able to reach the Kubernetes API server and have permission to + read and delete pods in the driver's namespace; otherwise pod tracking and cleanup will fail. 
+* This path bypasses ``ResumableJobMixin``, so Airflow retries submit a fresh driver instead of + reconnecting to an existing one. Set ``execution_timeout`` to bound wall-clock time. +* Pod completion is detected from ``pod.status.phase``. If your driver pods have sidecar + containers (e.g. Istio injection enabled for the driver namespace), the pod phase may not + advance to ``Succeeded`` until all sidecars exit. In that case the poll loop will wait + indefinitely — set ``execution_timeout`` as a hard bound. + YARN ResourceManager API tracking """"""""""""""""""""""""""""""""" diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py index 7306563a078e5..7cf1f3248adc8 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py @@ -52,6 +52,8 @@ DEFAULT_SPARK_BINARY = "spark-submit" ALLOWED_SPARK_BINARIES = [DEFAULT_SPARK_BINARY, "spark2-submit", "spark3-submit"] +_K8S_WAIT_APP_COMPLETION_CONF = "spark.kubernetes.submission.waitAppCompletion" + class SparkSubmitHook(BaseHook, LoggingMixin): """ @@ -90,11 +92,11 @@ class SparkSubmitHook(BaseHook, LoggingMixin): :param name: Name of the job (default airflow-spark) :param num_executors: Number of executors to launch :param status_poll_interval: Seconds to wait between polls of driver status in cluster - mode. Used both by the Spark standalone driver-status tracker and (when - ``yarn_track_via_rm_api=True``) by the YARN ResourceManager REST API - polling loop. The YARN ResourceManager REST API polling loop uses at - least 10 seconds to avoid flooding the ResourceManager on long-running - jobs (Default: 1). + mode (Default: 1). 
Controls three polling loops — each enforces its own minimum: + + - Spark standalone driver-status tracker (no minimum) + - YARN ResourceManager REST API, when ``yarn_track_via_rm_api=True`` (10s minimum) + - Kubernetes API, when ``track_driver_via_k8s_api=True`` (20s minimum) :param application_args: Arguments for the application being submitted :param env_vars: Environment variables for spark-submit. It supports yarn and k8s mode too. @@ -114,6 +116,13 @@ class SparkSubmitHook(BaseHook, LoggingMixin): job finishes (on both success and on_kill). Useful for cleaning up sidecars such as Istio (e.g. ``["curl -X POST localhost:15020/quitquitquit"]``). Each command is executed via the shell; failures produce a warning but do not fail the task. + :param track_driver_via_k8s_api: If True (when master is Kubernetes and + ``deploy_mode`` is ``cluster``), release the ``spark-submit`` JVM once the + driver pod has been created, then poll the Kubernetes API for the pod phase + until the application reaches a terminal state. The polling interval is + controlled by ``status_poll_interval`` with a 20-second minimum. This frees + the worker from holding the long-lived submit JVM (~500 MB). Defaults to + ``False``. :param yarn_track_via_rm_api: If True (when master is YARN and ``deploy_mode`` is ``cluster``), release the ``spark-submit`` JVM once the application has been submitted to YARN, then poll the YARN ResourceManager REST API @@ -257,6 +266,7 @@ def __init__( *, use_krb5ccache: bool = False, post_submit_commands: list[str] | None = None, + track_driver_via_k8s_api: bool = False, yarn_track_via_rm_api: bool = False, yarn_rm_auth: AuthBase | None = None, ) -> None: @@ -302,12 +312,14 @@ def __init__( f"{self._connection['master']} specified by kubernetes dependencies are not installed!" 
) + self._track_driver_via_k8s_api = track_driver_via_k8s_api self._should_track_driver_status = self._resolve_should_track_driver_status() self._driver_id: str | None = None self._driver_status: str | None = None self._spark_exit_code: int | None = None self._env: dict[str, Any] | None = None self._post_submit_commands: list[str] = list(post_submit_commands) if post_submit_commands else [] + self._post_submit_commands_done: bool = False self._yarn_track_via_rm_api = yarn_track_via_rm_api self._yarn_rm_auth = yarn_rm_auth # Cached after first successful resolution so the polling loop in @@ -326,6 +338,34 @@ def _resolve_should_track_driver_status(self) -> bool: """ return "spark://" in self._connection["master"] and self._connection["deploy_mode"] == "cluster" + def _should_track_driver_via_k8s_api(self) -> bool: + return ( + self._track_driver_via_k8s_api + and self._is_kubernetes + and self._connection["deploy_mode"] == "cluster" + ) + + def _validate_track_driver_via_k8s_api_config(self) -> None: + if not self._is_kubernetes: + raise ValueError( + "`track_driver_via_k8s_api=True` requires Spark master to be Kubernetes (k8s://...)." + ) + if self._connection["deploy_mode"] != "cluster": + raise ValueError( + "`track_driver_via_k8s_api=True` requires `deploy_mode='cluster'`; " + f"got deploy_mode={self._connection['deploy_mode']!r}." + ) + if not self._connection.get("namespace"): + raise ValueError( + "`track_driver_via_k8s_api=True` requires a namespace; " + "set it in the connection extra as `namespace` or via `spark.kubernetes.namespace` in conf." + ) + if str(self._conf.get(_K8S_WAIT_APP_COMPLETION_CONF, "")).lower() == "true": + raise ValueError( + f"`track_driver_via_k8s_api=True` is incompatible with " + f"`{_K8S_WAIT_APP_COMPLETION_CONF}=true`; remove it from your conf or set it to 'false'." 
+ ) + def _should_track_yarn_application_via_rm_api(self) -> bool: """Return whether this submit should switch to YARN RM REST API polling.""" return self._yarn_track_via_rm_api and self._is_yarn and self._connection["deploy_mode"] == "cluster" @@ -581,6 +621,10 @@ def _build_spark_common_args(self) -> list[str]: if self._connection["deploy_mode"]: args += ["--deploy-mode", self._connection["deploy_mode"]] + if self._should_track_driver_via_k8s_api(): + if _K8S_WAIT_APP_COMPLETION_CONF not in self._conf: + args += ["--conf", f"{_K8S_WAIT_APP_COMPLETION_CONF}=false"] + return args def _build_spark_submit_command(self, application: str) -> list[str]: @@ -665,7 +709,12 @@ def _run_post_submit_commands(self) -> None: Called after the Spark job finishes (success or on_kill). Typical use case is killing sidecars like Istio that don't shut down automatically. Failures are logged as warnings and never raise. + Guaranteed to run at most once per hook instance even if called from both + the poll-loop finally and on_kill (e.g. after a SIGTERM). """ + if self._post_submit_commands_done: + return + self._post_submit_commands_done = True for cmd in self._post_submit_commands: self.log.debug("Running post-submit command: %s", cmd) try: @@ -719,8 +768,14 @@ def submit(self, application: str = "", **kwargs: Any) -> str | None: # Check spark-submit return code. In Kubernetes mode, also check the value # of exit code in the log, as it may differ. + # When polling via K8s API, spark-submit exits after pod creation (waitAppCompletion=false) + # so _spark_exit_code is never set by the JVM watcher — skip that check entirely. try: - if returncode or (self._is_kubernetes and self._spark_exit_code != 0): + if returncode or ( + self._is_kubernetes + and not self._should_track_driver_via_k8s_api() + and self._spark_exit_code != 0 + ): if self._is_kubernetes: raise AirflowException( f"Cannot execute: {self._mask_cmd(spark_submit_cmd)}. Error code is: {returncode}. 
" @@ -744,10 +799,11 @@ def submit(self, application: str = "", **kwargs: Any) -> str | None: "No driver id is known: something went wrong when executing the spark submit command" ) finally: - # In cluster mode with driver tracking, the operator calls poll_until_complete - # after submit() returns, so post_submit_commands are deferred there to preserve - # the "runs after job finishes" contract. In all other modes, run them here. - if not self._should_track_driver_status: + # K8s-API tracking defers post-submit commands to _poll_k8s_driver_via_api's finally + # block so they run once after the driver reaches a terminal state. Spark cluster-mode + # driver tracking defers them to poll_until_complete for the same reason. All other + # modes run them here, immediately after spark-submit exits. + if not self._should_track_driver_status and not self._should_track_driver_via_k8s_api(): self._run_post_submit_commands() return self._driver_id @@ -778,10 +834,17 @@ def _process_spark_submit_log(self, itr: Iterator[Any]) -> None: # If we run Kubernetes cluster mode, we want to extract the driver pod id # from the logs so we can kill the application when we stop it unexpectedly elif self._is_kubernetes: + # Two log formats exist: "pod name: POD-driver" (pod status watcher) and + # "submission ID NAMESPACE:POD-driver" (the namespace prefix is whatever the job runs in) match_driver_pod = re.search(r"\s*pod name: ((.+?)-([a-z0-9]+)-driver$)", line) if match_driver_pod: self._kubernetes_driver_pod = match_driver_pod.group(1) self.log.info("Identified spark driver pod: %s", self._kubernetes_driver_pod) + if not self._kubernetes_driver_pod: + match_submission_id = re.search(r"submission ID [^\s:]+:(\S+-driver)\b", line) + if match_submission_id: + self._kubernetes_driver_pod = match_submission_id.group(1) + self.log.info("Identified spark driver pod: %s", self._kubernetes_driver_pod) match_application_id = re.search(r"\s*spark-app-selector -> (spark-([a-z0-9]+)), ", line) if 
match_application_id: @@ -1047,6 +1110,96 @@ def 
_start_driver_status_tracking(self) -> None: f"returncode = {returncode}" ) + def _poll_k8s_driver_via_api(self) -> None: + """Poll the K8s driver pod phase until it reaches a terminal state.""" + pod_name = self._kubernetes_driver_pod + namespace = self._connection["namespace"] + app_id = self._kubernetes_application_id or pod_name + + client = kube_client.get_kube_client() + poll_interval = max(self._status_poll_interval, 20) + if poll_interval != self._status_poll_interval: + self.log.info( + "status_poll_interval=%ds is below the 20s minimum for K8s API polling; using 20s.", + self._status_poll_interval, + ) + # Mirror `missed_job_status_reports` / `max_missed_job_status_reports` from + # `_start_driver_status_tracking`: tolerate transient failures before giving up. + consecutive_unknown = 0 + max_consecutive_unknown = 3 + consecutive_api_errors = 0 + max_consecutive_api_errors = 3 + consecutive_pending = 0 + pending_warn_threshold = 10 + + try: + if not pod_name: + raise ValueError("K8s driver pod name not set; cannot poll status.") + while True: + try: + pod = client.read_namespaced_pod(pod_name, namespace) + consecutive_api_errors = 0 + except kube_client.ApiException as e: + if e.status == 404: + self.log.info( + "Driver pod %s not found (404); pod was likely deleted by on_kill. Exiting poll loop.", + pod_name, + ) + return + consecutive_api_errors += 1 + self.log.warning( + "ApiException polling pod %s (%d/%d): %s", + pod_name, + consecutive_api_errors, + max_consecutive_api_errors, + e, + ) + if consecutive_api_errors >= max_consecutive_api_errors: + raise RuntimeError( + f"K8s API unreachable after {consecutive_api_errors} consecutive errors " + f"while polling {app_id}; giving up." 
+ ) from e + time.sleep(poll_interval) + continue + + phase = pod.status.phase or "Initializing" + self.log.info("Application status for %s (phase: %s)", app_id, phase) + if phase == "Succeeded": + break + if phase == "Failed": + container_state = "" + if pod.status.container_statuses: + cs = pod.status.container_statuses[0] + if cs.state and cs.state.terminated: + container_state = f" exit_code={cs.state.terminated.exit_code} reason={cs.state.terminated.reason}" + raise RuntimeError(f"Spark application {app_id} failed (phase=Failed{container_state})") + if phase == "Pending": + consecutive_pending += 1 + if consecutive_pending == pending_warn_threshold: + self.log.warning( + "Driver pod %s has been Pending for %d polls (~%ds); " + "it may be unschedulable. Continuing to wait — set execution_timeout to bound wait time.", + pod_name, + consecutive_pending, + consecutive_pending * poll_interval, + ) + else: + consecutive_pending = 0 + + if phase == "Unknown": + consecutive_unknown += 1 + if consecutive_unknown >= max_consecutive_unknown: + raise RuntimeError( + f"Spark application {app_id} reported Unknown phase " + f"{consecutive_unknown} times consecutively; giving up." + ) + else: + consecutive_unknown = 0 + time.sleep(poll_interval) + self._delete_driver_pod() + finally: + self._run_post_submit_commands() + def _build_spark_driver_kill_command(self) -> list[str]: """ Construct the spark-submit command to kill a driver. 
@@ -1067,6 +1220,25 @@ def _build_spark_driver_kill_command(self) -> list[str]: return connection_cmd + def _delete_driver_pod(self) -> None: + """Delete the Kubernetes driver pod, logging a warning on failure.""" + import kubernetes + + self.log.info("Deleting driver pod %s on Kubernetes", self._kubernetes_driver_pod) + try: + client = kube_client.get_kube_client() + client.delete_namespaced_pod( + self._kubernetes_driver_pod, + self._connection["namespace"], + body=kubernetes.client.V1DeleteOptions(), + pretty=True, + ) + self.log.info("Deleted driver pod %s", self._kubernetes_driver_pod) + except kube_client.ApiException: + self.log.exception( + "Exception when attempting to delete driver pod %s", self._kubernetes_driver_pod + ) + def on_kill(self) -> None: """Kill Spark submit command.""" self.log.debug("Kill Command is being called") @@ -1080,6 +1252,11 @@ def on_kill(self) -> None: "Spark driver %s killed with return code: %s", self._driver_id, driver_kill.wait() ) + if self._should_track_driver_via_k8s_api() and self._kubernetes_driver_pod: + # spark-submit exits early under waitAppCompletion=false, so _submit_sp.poll() is + # not None during the poll loop — the deletion block below is skipped on kill. + self._delete_driver_pod() + if self._submit_sp and self._submit_sp.poll() is None: self.log.info("Sending kill signal to %s", self._connection["spark_binary"]) self._submit_sp.kill() @@ -1111,24 +1288,7 @@ def on_kill(self) -> None: self.log.info("YARN app killed with return code: %s", yarn_kill.wait()) if self._kubernetes_driver_pod: - self.log.info("Killing pod %s on Kubernetes", self._kubernetes_driver_pod) - - # Currently only instantiate Kubernetes client for killing a spark pod. 
- try: - import kubernetes - - client = kube_client.get_kube_client() - api_response = client.delete_namespaced_pod( - self._kubernetes_driver_pod, - self._connection["namespace"], - body=kubernetes.client.V1DeleteOptions(), - pretty=True, - ) - - self.log.info("Spark on K8s killed with response: %s", api_response) - - except kube_client.ApiException: - self.log.exception("Exception when attempting to kill Spark on K8s") + self._delete_driver_pod() # Opt-in REST kill path — uses the same RM endpoint as polling, no # `yarn` CLI dependency on the worker. Independent of `_submit_sp` diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py index ea7b4a8e4ef89..5321dbbb8befc 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py @@ -113,6 +113,12 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator): on keytab for Kerberos login :param post_submit_commands: Optional list of shell commands to run after the Spark job finishes. Useful for cleaning up sidecars such as Istio. Failures produce a warning but do not fail the task. + :param track_driver_via_k8s_api: If True (when master is Kubernetes and ``deploy_mode`` + is ``cluster``), release the ``spark-submit`` JVM once the driver pod has been + created, then poll the Kubernetes API for the pod phase until the application + reaches a terminal state. The polling interval is controlled by + ``status_poll_interval`` with a 20-second minimum. This frees the worker from + holding the long-lived submit JVM. Defaults to ``False``. 
:param yarn_track_via_rm_api: If True (when master is YARN and ``deploy_mode`` is ``cluster``), release the ``spark-submit`` JVM once the application has been submitted to YARN, then poll the YARN ResourceManager REST API @@ -188,6 +194,7 @@ def __init__( use_krb5ccache: bool = False, post_submit_commands: list[str] | None = None, reconnect_on_retry: bool = True, + track_driver_via_k8s_api: bool = False, yarn_track_via_rm_api: bool = False, yarn_rm_auth: AuthBase | None = None, openlineage_inject_parent_job_info: bool = conf.getboolean( @@ -236,6 +243,7 @@ def __init__( self._yarn_rm_auth = yarn_rm_auth self.reconnect_on_retry = reconnect_on_retry + self._track_driver_via_k8s_api = track_driver_via_k8s_api self._openlineage_inject_parent_job_info = openlineage_inject_parent_job_info self._openlineage_inject_transport_info = openlineage_inject_transport_info @@ -251,6 +259,8 @@ def execute(self, context: Context) -> None: if self._hook is None: self._hook = self._get_hook() hook = self._hook + if self._track_driver_via_k8s_api: + hook._validate_track_driver_via_k8s_api_config() if hook._should_track_driver_status: if self.reconnect_on_retry: return self.execute_resumable(context) @@ -258,6 +268,12 @@ def execute(self, context: Context) -> None: driver_id = self.submit_job(context) self.poll_until_complete(driver_id, context) return self.get_job_result(driver_id, context) + if hook._should_track_driver_via_k8s_api(): + # TODO: Wire into execute_resumable() via ResumableJobMixin + # (fill submit_job / poll_until_complete K8s stubs) to enable crash recovery. 
+ hook.submit(self.application) + hook._poll_k8s_driver_via_api() + return hook.submit(self.application) def submit_job(self, context: Context) -> str: @@ -402,6 +418,7 @@ def _get_hook(self) -> SparkSubmitHook: deploy_mode=self._deploy_mode, use_krb5ccache=self._use_krb5ccache, post_submit_commands=self.post_submit_commands, + track_driver_via_k8s_api=self._track_driver_via_k8s_api, yarn_track_via_rm_api=self._yarn_track_via_rm_api, yarn_rm_auth=self._yarn_rm_auth, ) diff --git a/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py index 301f720443dbc..f4a610a940814 100644 --- a/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py +++ b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py @@ -24,11 +24,14 @@ from types import ModuleType from unittest.mock import MagicMock, call, mock_open, patch +import kubernetes import pytest import requests +from kubernetes.client import V1Pod, V1PodStatus from airflow.models import Connection from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook +from airflow.providers.cncf.kubernetes import kube_client from airflow.providers.common.compat.sdk import AirflowException @@ -101,6 +104,14 @@ def setup_connections(self, create_connection_without_db): extra='{"deploy-mode": "client", "namespace": "mynamespace"}', ) ) + create_connection_without_db( + Connection( + conn_id="spark_k8s_cluster_no_namespace", + conn_type="spark", + host="k8s://https://k8s-master", + extra='{"deploy-mode": "cluster"}', + ) + ) create_connection_without_db( Connection(conn_id="spark_default_mesos", conn_type="spark", host="mesos://host", port=5050) ) @@ -930,6 +941,18 @@ def test_process_spark_submit_log_k8s_spark_3(self): # Then assert hook._spark_exit_code == 999 + def test_process_spark_submit_log_k8s_submission_id_format(self): + hook = SparkSubmitHook(conn_id="spark_k8s_cluster") + 
log_lines = [ + "INFO Client: Deployed Spark application arrow-spark with application ID " + "spark-1e22d65826b74ac2927249b0e607ed54 and submission ID " + "spark:arrow-spark-c8e2e29e73db9c93-driver into Kubernetes", + ] + + hook._process_spark_submit_log(log_lines) + + assert hook._kubernetes_driver_pod == "arrow-spark-c8e2e29e73db9c93-driver" + def test_process_spark_client_mode_submit_log_k8s(self): # Given hook = SparkSubmitHook(conn_id="spark_k8s_client") @@ -1146,6 +1169,29 @@ def test_k8s_process_on_kill(self, mock_popen, mock_client_method): "spark-pi-edf2ace37be7353a958b38733a12f8e6-driver", "mynamespace", **kwargs ) + @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client") + def test_on_kill_deletes_pod_when_k8s_api_tracking_and_submit_sp_already_exited(self, mock_get_client): + """on_kill must delete the driver pod when K8s-API tracking is active even if spark-submit + has already exited. + """ + hook = SparkSubmitHook(conn_id="spark_k8s_cluster", track_driver_via_k8s_api=True) + hook._kubernetes_driver_pod = "spark-app-abc-driver" + hook._kubernetes_application_id = "spark-abc" + hook._submit_sp = MagicMock() + # spark-submit already exited + hook._submit_sp.poll.return_value = 0 + + mock_client = mock_get_client.return_value + + hook.on_kill() + + mock_client.delete_namespaced_pod.assert_called_once_with( + "spark-app-abc-driver", + "mynamespace", + body=kubernetes.client.V1DeleteOptions(), + pretty=True, + ) + @pytest.mark.parametrize( ("command", "expected"), [ @@ -1347,7 +1393,6 @@ def test_create_keytab_path_from_base64_keytab_with_existing_keytab( def test_run_post_submit_commands_success(self, mock_run): """Test that post_submit_commands are run with shell=False and shlex.split.""" import subprocess - from unittest.mock import MagicMock mock_result = MagicMock(spec=subprocess.CompletedProcess) mock_result.returncode = 0 @@ -1375,7 +1420,6 @@ def test_run_post_submit_commands_success(self, mock_run): def 
test_run_post_submit_commands_nonzero_exit_warns(self, mock_run): """Test that a non-zero exit code logs a warning but does not raise.""" import subprocess - from unittest.mock import MagicMock mock_result = MagicMock(spec=subprocess.CompletedProcess) mock_result.returncode = 1 @@ -1412,6 +1456,163 @@ def test_post_submit_commands_none_gives_empty_list(self): hook = SparkSubmitHook(conn_id="") assert hook._post_submit_commands == [] + @pytest.mark.parametrize( + ("conn_id", "flag", "expected"), + [ + ("spark_k8s_cluster", False, False), + ("spark_k8s_cluster", True, True), + ("spark_k8s_client", True, False), + ], + ) + def test_should_track_driver_via_k8s_api(self, conn_id, flag, expected): + hook = SparkSubmitHook(conn_id=conn_id, track_driver_via_k8s_api=flag) + assert hook._should_track_driver_via_k8s_api() is expected + + @pytest.mark.parametrize( + ("conn_id", "match"), + [ + ("spark_yarn_cluster", "requires Spark master to be Kubernetes"), + ("spark_k8s_client", "requires `deploy_mode='cluster'`"), + ("spark_k8s_cluster_no_namespace", "requires a namespace"), + ], + ) + def test_validate_track_driver_via_k8s_api_raises(self, conn_id, match): + hook = SparkSubmitHook(conn_id=conn_id, track_driver_via_k8s_api=True) + with pytest.raises(ValueError, match=match): + hook._validate_track_driver_via_k8s_api_config() + + def test_validate_track_driver_via_k8s_api_raises_on_conflicting_user_conf(self): + hook = SparkSubmitHook( + conn_id="spark_k8s_cluster", + track_driver_via_k8s_api=True, + conf={"spark.kubernetes.submission.waitAppCompletion": "true"}, + ) + with pytest.raises(ValueError, match="incompatible with.*waitAppCompletion=true"): + hook._validate_track_driver_via_k8s_api_config() + + def test_conf_injection_adds_wait_app_completion(self): + hook = SparkSubmitHook(conn_id="spark_k8s_cluster", track_driver_via_k8s_api=True) + cmd = hook._build_spark_submit_command("app.jar") + conf_pairs = [cmd[i + 1] for i, v in enumerate(cmd) if v == "--conf"] + assert 
"spark.kubernetes.submission.waitAppCompletion=false" in conf_pairs + + @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client") + def test_poll_k8s_driver_succeeds(self, mock_get_client): + hook = SparkSubmitHook(conn_id="spark_k8s_cluster", track_driver_via_k8s_api=True) + hook._kubernetes_driver_pod = "spark-app-abc-driver" + hook._kubernetes_application_id = "spark-abc" + + mock_client = mock_get_client.return_value + running_pod = V1Pod(status=V1PodStatus(phase="Running")) + succeeded_pod = V1Pod(status=V1PodStatus(phase="Succeeded")) + mock_client.read_namespaced_pod.side_effect = [running_pod, succeeded_pod] + + with patch.object(hook, "_run_post_submit_commands"), patch("time.sleep"): + hook._poll_k8s_driver_via_api() + + assert mock_client.delete_namespaced_pod.call_args.args[:2] == ("spark-app-abc-driver", "mynamespace") + + @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client") + def test_poll_k8s_driver_raises_on_failed(self, mock_get_client): + hook = SparkSubmitHook(conn_id="spark_k8s_cluster", track_driver_via_k8s_api=True) + hook._kubernetes_driver_pod = "spark-app-abc-driver" + hook._kubernetes_application_id = "spark-abc" + + mock_client = mock_get_client.return_value + failed_pod = V1Pod(status=V1PodStatus(phase="Failed")) + mock_client.read_namespaced_pod.return_value = failed_pod + + with pytest.raises(RuntimeError, match="phase=Failed"): + hook._poll_k8s_driver_via_api() + + @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client") + def test_poll_k8s_driver_raises_after_consecutive_unknown(self, mock_get_client): + hook = SparkSubmitHook(conn_id="spark_k8s_cluster", track_driver_via_k8s_api=True) + hook._kubernetes_driver_pod = "spark-app-abc-driver" + hook._kubernetes_application_id = "spark-abc" + + mock_client = mock_get_client.return_value + mock_client.read_namespaced_pod.return_value = V1Pod(status=V1PodStatus(phase="Unknown")) + + with patch("time.sleep"), pytest.raises(RuntimeError, match="Unknown 
phase"): + 
hook._poll_k8s_driver_via_api() + + # assert that it was polled minimum 3 times to confirm the Unknown status before raising + assert mock_client.read_namespaced_pod.call_count == 3 + + @patch("time.sleep") + @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client") + def test_poll_k8s_driver_tolerates_transient_api_errors(self, mock_get_client, _): + hook = SparkSubmitHook(conn_id="spark_k8s_cluster", track_driver_via_k8s_api=True) + hook._kubernetes_driver_pod = "spark-app-abc-driver" + hook._kubernetes_application_id = "spark-abc" + + mock_client = mock_get_client.return_value + api_error = kube_client.ApiException(status=500, reason="Internal Server Error") + succeeded_pod = V1Pod(status=V1PodStatus(phase="Succeeded")) + mock_client.read_namespaced_pod.side_effect = [api_error, api_error, succeeded_pod] + + with patch.object(hook, "_run_post_submit_commands"): + hook._poll_k8s_driver_via_api() + + assert mock_client.read_namespaced_pod.call_count == 3 + + @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client") + def test_post_submit_commands_run_exactly_once_on_k8s_path(self, mock_get_client): + """_run_post_submit_commands must fire exactly once: in _poll_k8s_driver_via_api finally.""" + hook = SparkSubmitHook(conn_id="spark_k8s_cluster", track_driver_via_k8s_api=True) + hook._kubernetes_driver_pod = "spark-app-abc-driver" + hook._kubernetes_application_id = "spark-abc" + + mock_client = mock_get_client.return_value + mock_client.read_namespaced_pod.return_value = V1Pod(status=V1PodStatus(phase="Succeeded")) + + with patch.object(hook, "_run_post_submit_commands") as mock_cmd: + hook._poll_k8s_driver_via_api() + + mock_cmd.assert_called_once() + + @patch("time.sleep") + @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client") + def test_poll_k8s_driver_raises_after_consecutive_api_errors(self, mock_get_client, _): + hook = SparkSubmitHook(conn_id="spark_k8s_cluster", track_driver_via_k8s_api=True) + 
hook._kubernetes_driver_pod = "spark-app-abc-driver" + hook._kubernetes_application_id = "spark-abc" + + mock_client = mock_get_client.return_value + api_error = kube_client.ApiException(status=500, reason="Internal Server Error") + mock_client.read_namespaced_pod.side_effect = api_error + + with pytest.raises(RuntimeError, match="K8s API unreachable"): + hook._poll_k8s_driver_via_api() + + assert mock_client.read_namespaced_pod.call_count == 3 + + @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client") + def test_poll_k8s_driver_exits_cleanly_on_404(self, mock_get_client): + """404 from read_namespaced_pod means pod was deleted by on_kill — should return cleanly, not raise.""" + hook = SparkSubmitHook(conn_id="spark_k8s_cluster", track_driver_via_k8s_api=True) + hook._kubernetes_driver_pod = "spark-app-abc-driver" + hook._kubernetes_application_id = "spark-abc" + + mock_client = mock_get_client.return_value + mock_client.read_namespaced_pod.side_effect = kube_client.ApiException(status=404, reason="Not Found") + + hook._poll_k8s_driver_via_api() + + mock_client.delete_namespaced_pod.assert_not_called() + + @patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.run") + def test_run_post_submit_commands_runs_only_once(self, mock_run): + """Calling _run_post_submit_commands twice must execute commands exactly once.""" + mock_run.return_value = MagicMock(returncode=0, stdout="") + hook = SparkSubmitHook(conn_id="spark_k8s_cluster", post_submit_commands=["echo done"]) + + hook._run_post_submit_commands() + hook._run_post_submit_commands() + + mock_run.assert_called_once() + _YARN_LOG_LINES = [ "INFO Client: Requesting a new application from cluster with 1 NodeManagers", "INFO Client: Uploading resource file:/tmp/lib.zip -> " diff --git a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py index e18de1dd57648..47cada84ce6db 100644 --- 
a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py +++ b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py @@ -733,4 +733,45 @@ def simulate_failed_tracking(): with pytest.raises(RuntimeError, match="FAILED"): operator.poll_until_complete("driver-001", {}) assert post_submit_called, "_run_post_submit_commands must be called even on driver failure" + + +class TestSparkSubmitOperatorK8sTracking: + def setup_method(self): + args = {"owner": "airflow", "start_date": DEFAULT_DATE} + self.dag = DAG("test_k8s_tracking_dag", schedule=None, default_args=args) + + def _make_operator(self, **kwargs): + return SparkSubmitOperator(task_id="test", dag=self.dag, application="test.jar", **kwargs) + + def _make_k8s_hook(self): + hook = MagicMock() + hook._should_track_driver_status = False + hook._should_track_driver_via_k8s_api.return_value = True + return hook + + def test_execute_calls_submit_then_poll_when_flag_set(self): + operator = self._make_operator(track_driver_via_k8s_api=True) + hook = self._make_k8s_hook() + operator._hook = hook + call_order = [] + hook.submit.side_effect = lambda *a, **kw: call_order.append("submit") + hook._poll_k8s_driver_via_api.side_effect = lambda: call_order.append("poll") + + operator.execute(context={}) + + hook.submit.assert_called_once_with("test.jar") + hook._poll_k8s_driver_via_api.assert_called_once() + assert call_order == ["submit", "poll"] + + def test_execute_falls_through_to_plain_submit_when_flag_off(self): + operator = self._make_operator(track_driver_via_k8s_api=False) + hook = MagicMock() + hook._should_track_driver_status = False + hook._should_track_driver_via_k8s_api.return_value = False + operator._hook = hook + + operator.execute(context={}) + + hook.submit.assert_called_once_with("test.jar") + hook._poll_k8s_driver_via_api.assert_not_called()