Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions providers/apache/spark/docs/operators.rst
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,43 @@ See :doc:`connections/spark-submit` for how to configure these fields.
Crash recovery in cluster mode requires Airflow 3.3+ (``task_state`` support). On earlier
versions the operator falls back to the previous behavior of always submitting fresh.

Tracking driver status via Kubernetes API
""""""""""""""""""""""""""""""""""""""""""

When running in Kubernetes cluster mode, ``spark-submit`` blocks for the duration of the job.
The JVM runs processes which does nothing but polling of the pod phase and holds heap space for
the entire duration. This is not ideal for long-running jobs, especially when the driver is idle
for long periods (e.g. waiting for data or user input).

Set ``track_driver_via_k8s_api=True`` to have the operator track the driver pod status via the
Python Kubernetes client rather than holding ``spark-submit`` open for the full job duration:

.. code-block:: python

from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator

run_spark = SparkSubmitOperator(
task_id="run_spark",
application="local:///opt/spark/examples/jars/spark-examples.jar",
conn_id="spark_k8s",
deploy_mode="cluster",
track_driver_via_k8s_api=True,
)

**Requirements**

* The Spark connection ``master`` must be ``k8s://...`` and ``deploy_mode`` must be ``cluster``.
* Do not set ``spark.kubernetes.submission.waitAppCompletion=true`` in your ``conf`` — this
conflicts with the flag and a ``ValueError`` will be raised at task start.
* The Airflow worker must be able to reach the Kubernetes API server and have permission to
read and delete pods in the driver's namespace; otherwise pod tracking and cleanup will fail.
* This path bypasses ``ResumableJobMixin``, so Airflow retries submit a fresh driver instead of
reconnecting to an existing one. Set ``execution_timeout`` to bound wall-clock time.
* Pod completion is detected from ``pod.status.phase``. If your driver pods have sidecar
containers (e.g. Istio injection enabled for the driver namespace), the pod phase may not
advance to ``Succeeded`` until all sidecars exit. In that case the poll loop will wait
indefinitely — set ``execution_timeout`` as a hard bound.

YARN ResourceManager API tracking
"""""""""""""""""""""""""""""""""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
DEFAULT_SPARK_BINARY = "spark-submit"
ALLOWED_SPARK_BINARIES = [DEFAULT_SPARK_BINARY, "spark2-submit", "spark3-submit"]

_K8S_WAIT_APP_COMPLETION_CONF = "spark.kubernetes.submission.waitAppCompletion"


class SparkSubmitHook(BaseHook, LoggingMixin):
"""
Expand Down Expand Up @@ -90,11 +92,11 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
:param name: Name of the job (default airflow-spark)
:param num_executors: Number of executors to launch
:param status_poll_interval: Seconds to wait between polls of driver status in cluster
mode. Used both by the Spark standalone driver-status tracker and (when
``yarn_track_via_rm_api=True``) by the YARN ResourceManager REST API
polling loop. The YARN ResourceManager REST API polling loop uses at
least 10 seconds to avoid flooding the ResourceManager on long-running
jobs (Default: 1).
mode (Default: 1). Controls three polling loops — each enforces its own minimum:

- Spark standalone driver-status tracker (no minimum)
- YARN ResourceManager REST API, when ``yarn_track_via_rm_api=True`` (10s minimum)
- Kubernetes API, when ``track_driver_via_k8s_api=True`` (20s minimum)
:param application_args: Arguments for the application being submitted
:param env_vars: Environment variables for spark-submit. It
supports yarn and k8s mode too.
Expand All @@ -114,6 +116,13 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
job finishes (on both success and on_kill). Useful for cleaning up sidecars such
as Istio (e.g. ``["curl -X POST localhost:15020/quitquitquit"]``). Each command
is executed via the shell; failures produce a warning but do not fail the task.
:param track_driver_via_k8s_api: If True (when master is Kubernetes and
``deploy_mode`` is ``cluster``), release the ``spark-submit`` JVM once the
driver pod has been created, then poll the Kubernetes API for the pod phase
until the application reaches a terminal state. The polling interval is
controlled by ``status_poll_interval`` with a 20-second minimum. This frees
the worker from holding the long-lived submit JVM (~500 MB). Defaults to
``False``.
:param yarn_track_via_rm_api: If True (when master is YARN and ``deploy_mode``
is ``cluster``), release the ``spark-submit`` JVM once the application has
been submitted to YARN, then poll the YARN ResourceManager REST API
Expand Down Expand Up @@ -257,6 +266,7 @@ def __init__(
*,
use_krb5ccache: bool = False,
post_submit_commands: list[str] | None = None,
track_driver_via_k8s_api: bool = False,
yarn_track_via_rm_api: bool = False,
yarn_rm_auth: AuthBase | None = None,
) -> None:
Expand Down Expand Up @@ -302,12 +312,14 @@ def __init__(
f"{self._connection['master']} specified by kubernetes dependencies are not installed!"
)

self._track_driver_via_k8s_api = track_driver_via_k8s_api
self._should_track_driver_status = self._resolve_should_track_driver_status()
self._driver_id: str | None = None
self._driver_status: str | None = None
self._spark_exit_code: int | None = None
self._env: dict[str, Any] | None = None
self._post_submit_commands: list[str] = list(post_submit_commands) if post_submit_commands else []
self._post_submit_commands_done: bool = False
self._yarn_track_via_rm_api = yarn_track_via_rm_api
self._yarn_rm_auth = yarn_rm_auth
# Cached after first successful resolution so the polling loop in
Expand All @@ -326,6 +338,34 @@ def _resolve_should_track_driver_status(self) -> bool:
"""
return "spark://" in self._connection["master"] and self._connection["deploy_mode"] == "cluster"

def _should_track_driver_via_k8s_api(self) -> bool:
return (
self._track_driver_via_k8s_api
and self._is_kubernetes
and self._connection["deploy_mode"] == "cluster"
)

def _validate_track_driver_via_k8s_api_config(self) -> None:
if not self._is_kubernetes:
raise ValueError(
"`track_driver_via_k8s_api=True` requires Spark master to be Kubernetes (k8s://...)."
)
if self._connection["deploy_mode"] != "cluster":
raise ValueError(
"`track_driver_via_k8s_api=True` requires `deploy_mode='cluster'`; "
f"got deploy_mode={self._connection['deploy_mode']!r}."
)
if not self._connection.get("namespace"):
raise ValueError(
"`track_driver_via_k8s_api=True` requires a namespace; "
"set it in the connection extra as `namespace` or via `spark.kubernetes.namespace` in conf."
)
if str(self._conf.get(_K8S_WAIT_APP_COMPLETION_CONF, "")).lower() == "true":
raise ValueError(
f"`track_driver_via_k8s_api=True` is incompatible with "
f"`{_K8S_WAIT_APP_COMPLETION_CONF}=true`; remove it from your conf or set it to 'false'."
)

def _should_track_yarn_application_via_rm_api(self) -> bool:
"""Return whether this submit should switch to YARN RM REST API polling."""
return self._yarn_track_via_rm_api and self._is_yarn and self._connection["deploy_mode"] == "cluster"
Expand Down Expand Up @@ -581,6 +621,10 @@ def _build_spark_common_args(self) -> list[str]:
if self._connection["deploy_mode"]:
args += ["--deploy-mode", self._connection["deploy_mode"]]

if self._should_track_driver_via_k8s_api():
if _K8S_WAIT_APP_COMPLETION_CONF not in self._conf:
args += ["--conf", f"{_K8S_WAIT_APP_COMPLETION_CONF}=false"]

return args

def _build_spark_submit_command(self, application: str) -> list[str]:
Expand Down Expand Up @@ -665,7 +709,12 @@ def _run_post_submit_commands(self) -> None:
Called after the Spark job finishes (success or on_kill). Typical use case
is killing sidecars like Istio that don't shut down automatically.
Failures are logged as warnings and never raise.
Guaranteed to run at most once per hook instance even if called from both
the poll-loop finally and on_kill (e.g. after a SIGTERM).
"""
if self._post_submit_commands_done:
return
self._post_submit_commands_done = True
for cmd in self._post_submit_commands:
self.log.debug("Running post-submit command: %s", cmd)
try:
Expand Down Expand Up @@ -719,8 +768,14 @@ def submit(self, application: str = "", **kwargs: Any) -> str | None:

# Check spark-submit return code. In Kubernetes mode, also check the value
# of exit code in the log, as it may differ.
# When polling via K8s API, spark-submit exits after pod creation (waitAppCompletion=false)
# so _spark_exit_code is never set by the JVM watcher — skip that check entirely.
try:
if returncode or (self._is_kubernetes and self._spark_exit_code != 0):
if returncode or (
self._is_kubernetes
and not self._should_track_driver_via_k8s_api()
and self._spark_exit_code != 0
):
if self._is_kubernetes:
raise AirflowException(
f"Cannot execute: {self._mask_cmd(spark_submit_cmd)}. Error code is: {returncode}. "
Expand All @@ -744,10 +799,11 @@ def submit(self, application: str = "", **kwargs: Any) -> str | None:
"No driver id is known: something went wrong when executing the spark submit command"
)
finally:
# In cluster mode with driver tracking, the operator calls poll_until_complete
# after submit() returns, so post_submit_commands are deferred there to preserve
# the "runs after job finishes" contract. In all other modes, run them here.
if not self._should_track_driver_status:
# K8s-API tracking defers post-submit commands to _poll_k8s_driver_via_api's finally
# block so they run once after the driver reaches a terminal state. Spark cluster-mode
# driver tracking defers them to poll_until_complete for the same reason. All other
# modes run them here, immediately after spark-submit exits.
if not self._should_track_driver_status and not self._should_track_driver_via_k8s_api():
self._run_post_submit_commands()

return self._driver_id
Expand Down Expand Up @@ -778,10 +834,17 @@ def _process_spark_submit_log(self, itr: Iterator[Any]) -> None:
# If we run Kubernetes cluster mode, we want to extract the driver pod id
# from the logs so we can kill the application when we stop it unexpectedly
elif self._is_kubernetes:
# Two log formats exist across Spark versions:
# "pod name: <name>-driver" and "submission ID spark:<name>-driver"
match_driver_pod = re.search(r"\s*pod name: ((.+?)-([a-z0-9]+)-driver$)", line)
if match_driver_pod:
self._kubernetes_driver_pod = match_driver_pod.group(1)
self.log.info("Identified spark driver pod: %s", self._kubernetes_driver_pod)
if not self._kubernetes_driver_pod:
match_submission_id = re.search(r"submission ID spark:(.+?-driver)", line)
if match_submission_id:
self._kubernetes_driver_pod = match_submission_id.group(1)
self.log.info("Identified spark driver pod: %s", self._kubernetes_driver_pod)

match_application_id = re.search(r"\s*spark-app-selector -> (spark-([a-z0-9]+)), ", line)
if match_application_id:
Expand Down Expand Up @@ -1047,6 +1110,96 @@ def _start_driver_status_tracking(self) -> None:
f"returncode = {returncode}"
)

def _poll_k8s_driver_via_api(self) -> None:
"""Poll the K8s driver pod phase until it reaches a terminal state."""
pod_name = self._kubernetes_driver_pod
namespace = self._connection["namespace"]
app_id = self._kubernetes_application_id or pod_name

client = kube_client.get_kube_client()
poll_interval = max(self._status_poll_interval, 20)
if poll_interval != self._status_poll_interval:
self.log.info(
"status_poll_interval=%ds is below the 20s minimum for K8s API polling; using 20s.",
self._status_poll_interval,
)
# Mirror `missed_job_status_reports` / `max_missed_job_status_reports` from
# `_start_driver_status_tracking`: tolerate transient failures before giving up.
consecutive_unknown = 0
max_consecutive_unknown = 3
consecutive_api_errors = 0
max_consecutive_api_errors = 3
consecutive_pending = 0
pending_warn_threshold = 10

try:
if not pod_name:
raise ValueError("K8s driver pod name not set; cannot poll status.")
while True:
try:
pod = client.read_namespaced_pod(pod_name, namespace)
consecutive_api_errors = 0
except kube_client.ApiException as e:
if e.status == 404:
self.log.info(
"Driver pod %s not found (404); pod was likely deleted by on_kill. Exiting poll loop.",
pod_name,
)
return
consecutive_api_errors += 1
self.log.warning(
"ApiException polling pod %s (%d/%d): %s",
pod_name,
consecutive_api_errors,
max_consecutive_api_errors,
e,
)
if consecutive_api_errors >= max_consecutive_api_errors:
raise RuntimeError(
f"K8s API unreachable after {consecutive_api_errors} consecutive errors "
f"while polling {app_id}; giving up."
) from e
time.sleep(poll_interval)
continue

phase = pod.status.phase or "Initializing"
self.log.info("Application status for %s (phase: %s)", app_id, phase)
if phase == "Succeeded":
break
if phase == "Failed":
container_state = ""
if pod.status.container_statuses:
cs = pod.status.container_statuses[0]
if cs.state and cs.state.terminated:
container_state = f" exit_code={cs.state.terminated.exit_code} reason={cs.state.terminated.reason}"
raise RuntimeError(f"Spark application {app_id} failed (phase=Failed{container_state})")
if phase == "Pending":
consecutive_pending += 1
if consecutive_pending == pending_warn_threshold:
self.log.warning(
"Driver pod %s has been Pending for %d polls (~%ds); "
"it may be unschedulable. Continuing to wait — set execution_timeout to bound wait time.",
pod_name,
consecutive_pending,
consecutive_pending * poll_interval,
)
else:
consecutive_pending = 0

if phase == "Unknown":
consecutive_unknown += 1
if consecutive_unknown >= max_consecutive_unknown:
raise RuntimeError(
f"Spark application {app_id} reported Unknown phase "
f"{consecutive_unknown} times consecutively; giving up."
)
else:
consecutive_unknown = 0
time.sleep(poll_interval)
self._delete_driver_pod()
finally:
self._run_post_submit_commands()

def _build_spark_driver_kill_command(self) -> list[str]:
"""
Construct the spark-submit command to kill a driver.
Expand All @@ -1067,6 +1220,25 @@ def _build_spark_driver_kill_command(self) -> list[str]:

return connection_cmd

def _delete_driver_pod(self) -> None:
"""Delete the Kubernetes driver pod, logging a warning on failure."""
import kubernetes

self.log.info("Deleting driver pod %s on Kubernetes", self._kubernetes_driver_pod)
try:
client = kube_client.get_kube_client()
client.delete_namespaced_pod(
self._kubernetes_driver_pod,
self._connection["namespace"],
body=kubernetes.client.V1DeleteOptions(),
pretty=True,
)
self.log.info("Deleted driver pod %s", self._kubernetes_driver_pod)
except kube_client.ApiException:
self.log.exception(
"Exception when attempting to delete driver pod %s", self._kubernetes_driver_pod
)

def on_kill(self) -> None:
"""Kill Spark submit command."""
self.log.debug("Kill Command is being called")
Expand All @@ -1080,6 +1252,11 @@ def on_kill(self) -> None:
"Spark driver %s killed with return code: %s", self._driver_id, driver_kill.wait()
)

if self._should_track_driver_via_k8s_api() and self._kubernetes_driver_pod:
# spark-submit exits early under waitAppCompletion=false, so _submit_sp.poll() is
# not None during the poll loop — the deletion block below is skipped on kill.
self._delete_driver_pod()

if self._submit_sp and self._submit_sp.poll() is None:
self.log.info("Sending kill signal to %s", self._connection["spark_binary"])
self._submit_sp.kill()
Expand Down Expand Up @@ -1111,24 +1288,7 @@ def on_kill(self) -> None:
self.log.info("YARN app killed with return code: %s", yarn_kill.wait())

if self._kubernetes_driver_pod:
self.log.info("Killing pod %s on Kubernetes", self._kubernetes_driver_pod)

# Currently only instantiate Kubernetes client for killing a spark pod.
try:
import kubernetes

client = kube_client.get_kube_client()
api_response = client.delete_namespaced_pod(
self._kubernetes_driver_pod,
self._connection["namespace"],
body=kubernetes.client.V1DeleteOptions(),
pretty=True,
)

self.log.info("Spark on K8s killed with response: %s", api_response)

except kube_client.ApiException:
self.log.exception("Exception when attempting to kill Spark on K8s")
self._delete_driver_pod()

# Opt-in REST kill path — uses the same RM endpoint as polling, no
# `yarn` CLI dependency on the worker. Independent of `_submit_sp`
Expand Down
Loading
Loading