Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions airflow/providers/apache/spark/hooks/spark_submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
(will overwrite any spark_binary defined in the connection's extra JSON)
:param properties_file: Path to a file from which to load extra properties. If not
specified, this will look for conf/spark-defaults.conf.
:param queue: The name of the YARN queue to which the application is submitted.
:param yarn_queue: The name of the YARN queue to which the application is submitted.
(will overwrite any yarn queue defined in the connection's extra JSON)
:param deploy_mode: Whether to deploy your driver on the worker nodes (cluster) or locally as an client.
:param deploy_mode: Whether to deploy your driver on the worker nodes (cluster) or locally as a client.
(will overwrite any deployment mode defined in the connection's extra JSON)
:param use_krb5ccache: if True, configure spark to use ticket cache instead of relying
on keytab for Kerberos login
Expand Down Expand Up @@ -165,7 +165,7 @@ def __init__(
verbose: bool = False,
spark_binary: str | None = None,
properties_file: str | None = None,
queue: str | None = None,
yarn_queue: str | None = None,
deploy_mode: str | None = None,
*,
use_krb5ccache: bool = False,
Expand Down Expand Up @@ -201,7 +201,7 @@ def __init__(
self._kubernetes_driver_pod: str | None = None
self.spark_binary = spark_binary
self._properties_file = properties_file
self._queue = queue
self._yarn_queue = yarn_queue
self._deploy_mode = deploy_mode
self._connection = self._resolve_connection()
self._is_yarn = "yarn" in self._connection["master"]
Expand Down Expand Up @@ -231,7 +231,7 @@ def _resolve_connection(self) -> dict[str, Any]:
# Build from connection master or default to yarn if not available
conn_data = {
"master": "yarn",
"queue": None,
"queue": None, # yarn queue
"deploy_mode": None,
"spark_binary": self.spark_binary or DEFAULT_SPARK_BINARY,
"namespace": None,
Expand All @@ -248,7 +248,7 @@ def _resolve_connection(self) -> dict[str, Any]:

# Determine optional yarn queue from the extra field
extra = conn.extra_dejson
conn_data["queue"] = self._queue if self._queue else extra.get("queue")
conn_data["queue"] = self._yarn_queue if self._yarn_queue else extra.get("queue")
conn_data["deploy_mode"] = self._deploy_mode if self._deploy_mode else extra.get("deploy-mode")
if not self.spark_binary:
self.spark_binary = extra.get("spark-binary", DEFAULT_SPARK_BINARY)
Expand Down
8 changes: 4 additions & 4 deletions airflow/providers/apache/spark/operators/spark_submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class SparkSubmitOperator(BaseOperator):
(will overwrite any spark_binary defined in the connection's extra JSON)
:param properties_file: Path to a file from which to load extra properties. If not
specified, this will look for conf/spark-defaults.conf.
:param queue: The name of the YARN queue to which the application is submitted.
:param yarn_queue: The name of the YARN queue to which the application is submitted.
(will overwrite any yarn queue defined in the connection's extra JSON)
:param deploy_mode: Whether to deploy your driver on the worker nodes (cluster) or locally as a client.
(will overwrite any deployment mode defined in the connection's extra JSON)
Expand Down Expand Up @@ -129,7 +129,7 @@ def __init__(
verbose: bool = False,
spark_binary: str | None = None,
properties_file: str | None = None,
queue: str | None = None,
yarn_queue: str | None = None,
deploy_mode: str | None = None,
use_krb5ccache: bool = False,
**kwargs: Any,
Expand Down Expand Up @@ -161,7 +161,7 @@ def __init__(
self._verbose = verbose
self._spark_binary = spark_binary
self.properties_file = properties_file
self._queue = queue
self._yarn_queue = yarn_queue
self._deploy_mode = deploy_mode
self._hook: SparkSubmitHook | None = None
self._conn_id = conn_id
Expand Down Expand Up @@ -206,7 +206,7 @@ def _get_hook(self) -> SparkSubmitHook:
verbose=self._verbose,
spark_binary=self._spark_binary,
properties_file=self.properties_file,
queue=self._queue,
yarn_queue=self._yarn_queue,
deploy_mode=self._deploy_mode,
use_krb5ccache=self._use_krb5ccache,
)
20 changes: 13 additions & 7 deletions tests/providers/apache/spark/operators/test_spark_submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,9 @@ class TestSparkSubmitOperator:
"args should keep embedded spaces",
],
"use_krb5ccache": True,
"queue": "yarn_dev_queue2",
"yarn_queue": "yarn_dev_queue2",
"deploy_mode": "client2",
"queue": "airflow_custom_queue",
}

def setup_method(self):
Expand Down Expand Up @@ -122,10 +123,11 @@ def test_execute(self):
"args should keep embedded spaces",
],
"spark_binary": "sparky",
"queue": "yarn_dev_queue2",
"yarn_queue": "yarn_dev_queue2",
"deploy_mode": "client2",
"use_krb5ccache": True,
"properties_file": "conf/spark-custom.conf",
"queue": "airflow_custom_queue",
}

assert conn_id == operator._conn_id
Expand Down Expand Up @@ -153,10 +155,11 @@ def test_execute(self):
assert expected_dict["driver_memory"] == operator._driver_memory
assert expected_dict["application_args"] == operator.application_args
assert expected_dict["spark_binary"] == operator._spark_binary
assert expected_dict["queue"] == operator._queue
assert expected_dict["deploy_mode"] == operator._deploy_mode
assert expected_dict["properties_file"] == operator.properties_file
assert expected_dict["use_krb5ccache"] == operator._use_krb5ccache
assert expected_dict["queue"] == operator.queue
assert expected_dict["yarn_queue"] == operator._yarn_queue

@pytest.mark.db_test
def test_spark_submit_cmd_connection_overrides(self):
Expand All @@ -168,18 +171,21 @@ def test_spark_submit_cmd_connection_overrides(self):
task_id="spark_submit_job", spark_binary="sparky", dag=self.dag, **config
)
cmd = " ".join(operator._get_hook()._build_spark_submit_command("test"))
assert "--queue yarn_dev_queue2" in cmd
assert "--queue yarn_dev_queue2" in cmd # yarn queue
assert "--deploy-mode client2" in cmd
assert "sparky" in cmd
assert operator.queue == "airflow_custom_queue" # airflow queue

# if we don't pass any overrides in arguments
config["queue"] = None
# if we don't pass any overrides in arguments, the default values are used
config["yarn_queue"] = None
config["deploy_mode"] = None
config.pop("queue", None) # using default airflow queue
operator2 = SparkSubmitOperator(task_id="spark_submit_job2", dag=self.dag, **config)
cmd2 = " ".join(operator2._get_hook()._build_spark_submit_command("test"))
assert "--queue root.default" in cmd2
assert "--queue root.default" in cmd2 # yarn queue
assert "--deploy-mode client2" not in cmd2
assert "spark-submit" in cmd2
assert operator2.queue == "default" # airflow queue

@pytest.mark.db_test
def test_render_template(self):
Expand Down