Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,12 @@ class DockerOperator(BaseOperator):
The path is also made available via the environment variable
``AIRFLOW_TMP_DIR`` inside the container.
:param user: Default user inside the docker container.
:param mounts: List of volumes to mount into the container. Each item should
be a :py:class:`docker.types.Mount` instance. (templated)
:param mounts: List of volumes to mount into the container. Each item may
be a :py:class:`docker.types.Mount` instance, or a ``dict`` of
:py:class:`~docker.types.Mount` keyword arguments (e.g.
``{"target": "/data", "source": "vol", "type": "volume"}``); ``dict``
entries are converted to ``Mount`` instances at construction time.
(templated)
:param entrypoint: Overwrite the default ENTRYPOINT of the image
:param working_dir: Working directory to
set on the container (equivalent to the -w switch the docker client)
Expand Down Expand Up @@ -245,7 +249,7 @@ def __init__(
mount_tmp_dir: bool = True,
tmp_dir: str = "/tmp/airflow",
user: str | int | None = None,
mounts: list[Mount] | None = None,
mounts: list[Mount | dict] | None = None,
entrypoint: str | list[str] | None = None,
working_dir: str | None = None,
xcom_all: bool = False,
Expand Down Expand Up @@ -304,7 +308,8 @@ def __init__(
self.mount_tmp_dir = mount_tmp_dir
self.tmp_dir = tmp_dir
self.user = user
self.mounts = mounts or []
mounts = [mount if isinstance(mount, Mount) else Mount(**mount) for mount in (mounts or [])]
self.mounts: list[Mount] = mounts
for mount in self.mounts:
mount.template_fields = ("Source", "Target", "Type")
self.entrypoint = entrypoint
Expand Down
30 changes: 30 additions & 0 deletions providers/docker/tests/unit/docker/operators/test_docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,3 +818,33 @@ def test_basic_docker_operator_with_template_fields(self, create_task_instance_o
rendered = ti.render_templates()
assert rendered.container_name == f"python_{ti.dag_id}"
assert rendered.mounts[0]["Target"] == f"/{ti.run_id}"

def test_dict_mounts_are_normalized_to_mount_objects(self):
op = DockerOperator(
task_id="test",
image="test",
mounts=[
{"target": "/data", "source": "workspace", "type": "volume", "read_only": False},
Mount(target="/logs", source="logs", type="volume"),
],
)
assert all(isinstance(m, Mount) for m in op.mounts)
assert op.mounts[0]["Target"] == "/data"
assert op.mounts[0]["Source"] == "workspace"
assert op.mounts[0]["Type"] == "volume"
assert op.mounts[0]["ReadOnly"] is False
assert op.mounts[1]["Target"] == "/logs"

@pytest.mark.db_test
def test_dict_mounts_are_templated(self, create_task_instance_of_operator):
ti = create_task_instance_of_operator(
operator_class=DockerOperator,
dag_id="test",
task_id="test",
image="test",
mounts=[
{"target": "/{{task_instance.run_id}}", "source": "workspace", "type": "volume"},
],
)
rendered = ti.render_templates()
assert rendered.mounts[0]["Target"] == f"/{ti.run_id}"
Loading