diff --git a/providers/docker/src/airflow/providers/docker/operators/docker.py b/providers/docker/src/airflow/providers/docker/operators/docker.py index 345a6c624f1a7..e81579a4f1c23 100644 --- a/providers/docker/src/airflow/providers/docker/operators/docker.py +++ b/providers/docker/src/airflow/providers/docker/operators/docker.py @@ -151,8 +151,12 @@ class DockerOperator(BaseOperator): The path is also made available via the environment variable ``AIRFLOW_TMP_DIR`` inside the container. :param user: Default user inside the docker container. - :param mounts: List of volumes to mount into the container. Each item should - be a :py:class:`docker.types.Mount` instance. (templated) + :param mounts: List of volumes to mount into the container. Each item may + be a :py:class:`docker.types.Mount` instance, or a ``dict`` of + :py:class:`~docker.types.Mount` keyword arguments (e.g. + ``{"target": "/data", "source": "vol", "type": "volume"}``); ``dict`` + entries are converted to ``Mount`` instances at construction time. + (templated) :param entrypoint: Overwrite the default ENTRYPOINT of the image :param working_dir: Working directory to set on the container (equivalent to the -w switch the docker client) @@ -245,7 +249,7 @@ def __init__( mount_tmp_dir: bool = True, tmp_dir: str = "/tmp/airflow", user: str | int | None = None, - mounts: list[Mount] | None = None, + mounts: list[Mount | dict] | None = None, entrypoint: str | list[str] | None = None, working_dir: str | None = None, xcom_all: bool = False, @@ -304,7 +308,8 @@ def __init__( self.mount_tmp_dir = mount_tmp_dir self.tmp_dir = tmp_dir self.user = user - self.mounts = mounts or [] + mounts = [mount if isinstance(mount, Mount) else Mount(**mount) for mount in (mounts or [])] + self.mounts: list[Mount] = mounts for mount in self.mounts: mount.template_fields = ("Source", "Target", "Type") self.entrypoint = entrypoint diff --git a/providers/docker/tests/unit/docker/operators/test_docker.py b/providers/docker/tests/unit/docker/operators/test_docker.py index d375bd577ce2d..a753894d48f8c 100644 --- a/providers/docker/tests/unit/docker/operators/test_docker.py +++ b/providers/docker/tests/unit/docker/operators/test_docker.py @@ -818,3 +818,33 @@ def test_basic_docker_operator_with_template_fields(self, create_task_instance_o rendered = ti.render_templates() assert rendered.container_name == f"python_{ti.dag_id}" assert rendered.mounts[0]["Target"] == f"/{ti.run_id}" + + def test_dict_mounts_are_normalized_to_mount_objects(self): + op = DockerOperator( + task_id="test", + image="test", + mounts=[ + {"target": "/data", "source": "workspace", "type": "volume", "read_only": False}, + Mount(target="/logs", source="logs", type="volume"), + ], + ) + assert all(isinstance(m, Mount) for m in op.mounts) + assert op.mounts[0]["Target"] == "/data" + assert op.mounts[0]["Source"] == "workspace" + assert op.mounts[0]["Type"] == "volume" + assert op.mounts[0]["ReadOnly"] is False + assert op.mounts[1]["Target"] == "/logs" + + @pytest.mark.db_test + def test_dict_mounts_are_templated(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + operator_class=DockerOperator, + dag_id="test", + task_id="test", + image="test", + mounts=[ + {"target": "/{{task_instance.run_id}}", "source": "workspace", "type": "volume"}, + ], + ) + rendered = ti.render_templates() + assert rendered.mounts[0]["Target"] == f"/{ti.run_id}"