diff --git a/airflow-core/src/airflow/dag_processing/bundles/base.py b/airflow-core/src/airflow/dag_processing/bundles/base.py index b6b55f9251cfe..344a3349fecab 100644 --- a/airflow-core/src/airflow/dag_processing/bundles/base.py +++ b/airflow-core/src/airflow/dag_processing/bundles/base.py @@ -292,6 +292,8 @@ class BaseDagBundle(ABC): :param refresh_interval: How often the bundle should be refreshed from the source in seconds (Optional - defaults to [dag_processor] refresh_interval) :param version: Version of the DAG bundle (Optional) + :param version_data: Structured metadata for this bundle version, e.g. an S3 manifest. + Only populated for pinned runs (where dag_run.bundle_version is not None). (Optional) """ supports_versioning: bool = False @@ -304,10 +306,12 @@ def __init__( name: str, refresh_interval: int = conf.getint("dag_processor", "refresh_interval"), version: str | None = None, + version_data: dict[str, Any] | None = None, view_url_template: str | None = None, ) -> None: self.name = name self.version = version + self.version_data = version_data self.refresh_interval = refresh_interval self.is_initialized: bool = False diff --git a/airflow-core/src/airflow/dag_processing/bundles/manager.py b/airflow-core/src/airflow/dag_processing/bundles/manager.py index 78c54266eda9f..7d966e0b5074f 100644 --- a/airflow-core/src/airflow/dag_processing/bundles/manager.py +++ b/airflow-core/src/airflow/dag_processing/bundles/manager.py @@ -20,7 +20,7 @@ import logging import os import warnings -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from itsdangerous import URLSafeSerializer from pydantic import BaseModel, ValidationError @@ -395,19 +395,24 @@ def _extract_template_params(bundle_instance: BaseDagBundle) -> dict: return params - def get_bundle(self, name: str, version: str | None = None) -> BaseDagBundle: + def get_bundle( + self, name: str, version: str | None = None, version_data: dict[str, Any] | None = None + ) -> BaseDagBundle: """ 
Get a DAG bundle by name. :param name: The name of the DAG bundle. :param version: The version of the DAG bundle you need (optional). If not provided, ``tracking_ref`` will be used instead. + :param version_data: Optional structured data associated with this version (e.g., S3 manifest). :return: The DAG bundle. """ cfg_bundle = self._bundle_config.get(name) if not cfg_bundle: raise ValueError(f"Requested bundle '{name}' is not configured.") - return cfg_bundle.bundle_class(name=name, version=version, **cfg_bundle.kwargs) + # Forward version_data only when set so bundle classes predating the parameter still construct. + extra: dict[str, Any] = {} if version_data is None else {"version_data": version_data} + return cfg_bundle.bundle_class(name=name, version=version, **extra, **cfg_bundle.kwargs) def get_all_dag_bundles(self) -> Iterable[BaseDagBundle]: """ diff --git a/airflow-core/src/airflow/executors/workloads/base.py b/airflow-core/src/airflow/executors/workloads/base.py index 503cab7b3965a..41334d68f3038 100644 --- a/airflow-core/src/airflow/executors/workloads/base.py +++ b/airflow-core/src/airflow/executors/workloads/base.py @@ -21,7 +21,7 @@ import os from abc import ABC, abstractmethod from collections.abc import Hashable -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from pydantic import BaseModel, ConfigDict, Field @@ -66,6 +66,13 @@ class BundleInfo(BaseModel): name: str version: str | None = None + version_data: dict[str, Any] | None = None + """Optional structured metadata for this bundle version (e.g., an S3 object manifest). + + This field is serialized on every workload payload — executor command-line argv for + K8s/ECS/Batch/Lambda, message body for Celery/SQS. Keep payloads small to avoid hitting + transport limits (Linux caps each argv string at 128 KiB; the etcd PodSpec ceiling is ~1.5 MB). 
+ """ class BaseWorkloadSchema(BaseModel): diff --git a/airflow-core/src/airflow/executors/workloads/task.py b/airflow-core/src/airflow/executors/workloads/task.py index b4bf02ea47b8d..611457f88a957 100644 --- a/airflow-core/src/airflow/executors/workloads/task.py +++ b/airflow-core/src/airflow/executors/workloads/task.py @@ -102,9 +102,13 @@ def make( ser_ti = TaskInstanceDTO.model_validate(ti, from_attributes=True) if not bundle_info: + version_data = None + if ti.dag_version is not None and ti.dag_run.bundle_version is not None: + version_data = ti.dag_version.version_data bundle_info = BundleInfo( name=ti.dag_model.bundle_name, version=ti.dag_run.bundle_version, + version_data=version_data, ) fname = log_filename_template_renderer()(ti=ti) diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 2f94d480eb6d1..df2abe2aabd06 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -684,6 +684,12 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) - ranked_query.c.map_index_for_ordering, ) .options(selectinload(TI.dag_model)) + # Eager-load dag_version: TIs become transient (via make_transient) before + # ExecuteTask.make() reads ti.dag_version.version_data. Lazy loads on + # transient objects silently return None instead of raising DetachedInstanceError. + # Scope the second SELECT to version_data (the PK is auto-included) so we read + # two columns rather than the full DagVersion row. 
+ .options(selectinload(TI.dag_version).load_only(DagVersion.version_data)) ) query = query.limit(max_tis) diff --git a/airflow-core/tests/unit/dag_processing/bundles/test_base.py b/airflow-core/tests/unit/dag_processing/bundles/test_base.py index 6fc7ba39a0a12..f092f3e00e770 100644 --- a/airflow-core/tests/unit/dag_processing/bundles/test_base.py +++ b/airflow-core/tests/unit/dag_processing/bundles/test_base.py @@ -323,3 +323,16 @@ def test_bundle_version_inequality(self): bv1 = BundleVersion(version="abc", data={"key": "val"}) bv2 = BundleVersion(version="abc", data={"key": "other"}) assert bv1 != bv2 + + +def test_version_data_stored_on_bundle(): + """Test that version_data passed to a bundle constructor is stored on the instance.""" + manifest = {"schema_version": 1, "files": {"dags/my_dag.py": "S3VersionId123"}} + bundle = BasicBundle(name="test", version="abc", version_data=manifest) + assert bundle.version_data == manifest + + +def test_version_data_defaults_to_none(): + """Test that version_data defaults to None when not provided.""" + bundle = BasicBundle(name="test") + assert bundle.version_data is None diff --git a/airflow-core/tests/unit/executors/test_workloads.py b/airflow-core/tests/unit/executors/test_workloads.py index a063b91140dfb..63a249a36fd1b 100644 --- a/airflow-core/tests/unit/executors/test_workloads.py +++ b/airflow-core/tests/unit/executors/test_workloads.py @@ -171,3 +171,82 @@ def test_workload_ti_round_trips_through_sdk_generated_model(): assert received.queue == "jdk-17" assert received.map_index == 3 assert not hasattr(received, "pool_slots") + + +class TestExecuteTaskMakeVersionData: + """Tests for ExecuteTask.make() threading version_data through BundleInfo.""" + + @pytest.fixture(autouse=True) + def _stub_log_template(self, monkeypatch): + monkeypatch.setattr( + "airflow.utils.helpers.log_filename_template_renderer", + lambda: lambda **kwargs: "test.log", + ) + + @staticmethod + def _make_mock_ti(bundle_version, version_data, *, 
has_dag_version=True): + """Build a mock TI with the attributes ExecuteTask.make() reads. + + ``has_dag_version`` controls whether the TI has an associated DagVersion + (legacy/backfilled TIs may not), independently of ``version_data`` so the + pin-guard can be exercised with version_data present on an unpinned run. + """ + from unittest.mock import Mock + + ti = Mock() + ti.id = uuid4() + ti.dag_version_id = uuid4() + ti.task_id = "test_task" + ti.dag_id = "test_dag" + ti.run_id = "test_run" + ti.try_number = 1 + ti.map_index = -1 + ti.pool_slots = 1 + ti.queue = "default" + ti.priority_weight = 1 + ti.executor_config = None + ti.parent_context_carrier = None + ti.context_carrier = None + ti.hostname = None + ti.external_executor_id = None + + ti.dag_model.bundle_name = "test-bundle" + ti.dag_model.relative_fileloc = "dags/test_dag.py" + + ti.dag_run.bundle_version = bundle_version + + if has_dag_version: + ti.dag_version.version_data = version_data + else: + ti.dag_version = None + + return ti + + def test_pinned_run_populates_version_data(self): + """When the run is pinned, version_data from dag_version flows to BundleInfo.""" + version_data = {"schema_version": 1, "files": {"dags/my_dag.py": "ver123"}} + ti = self._make_mock_ti(bundle_version="abc123", version_data=version_data) + + workload = ExecuteTask.make(ti) + + assert workload.bundle_info.version == "abc123" + assert workload.bundle_info.version_data == version_data + + def test_unpinned_run_suppresses_present_version_data(self): + """An unpinned run must not expose version_data even when the dag_version carries it.""" + version_data = {"schema_version": 1, "files": {"dags/my_dag.py": "ver123"}} + ti = self._make_mock_ti(bundle_version=None, version_data=version_data) + + workload = ExecuteTask.make(ti) + + assert workload.bundle_info.version is None + assert workload.bundle_info.version_data is None + + def test_missing_dag_version_yields_none(self): + """A pinned run whose TI has no dag_version 
(legacy/backfilled) yields no version_data.""" + ti = self._make_mock_ti(bundle_version="abc123", version_data=None, has_dag_version=False) + + workload = ExecuteTask.make(ti) + + assert workload.bundle_info.version == "abc123" + assert workload.bundle_info.version_data is None diff --git a/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py b/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py index cf929e48d148d..211ed64acf604 100644 --- a/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py +++ b/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py @@ -814,6 +814,7 @@ def test_try_adopt_task_instances(self, mock_executor): task.dag_model = mock.Mock() task.dag_model.bundle_name = "test_bundle" task.dag_model.relative_fileloc = "test_dag.py" + task.dag_version = mock.Mock(version_data=None) task.dag_run = mock.Mock() task.dag_run.bundle_version = "1.0.0" task.dag_run.context_carrier = {} diff --git a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py index ca6fab54fa2c3..de5a17acf9666 100644 --- a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py +++ b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py @@ -1309,6 +1309,7 @@ def test_try_adopt_task_instances(self, mock_executor): task.dag_model = mock.Mock() task.dag_model.bundle_name = "test_bundle" task.dag_model.relative_fileloc = "test_dag.py" + task.dag_version = mock.Mock(version_data=None) task.dag_run = mock.Mock() task.dag_run.bundle_version = "1.0.0" task.dag_run.context_carrier = {} diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml index 7a7fb194fb03d..4bd6a68af4326 100644 --- 
a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml +++ b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml @@ -965,6 +965,12 @@ components: - type: string - type: 'null' title: Version + version_data: + anyOf: + - additionalProperties: true + type: object + - type: 'null' + title: Version Data type: object required: - name diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index bc03569386d14..139632fc0a037 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -570,6 +570,7 @@ class BundleInfo(BaseModel): name: Annotated[str, Field(title="Name")] version: Annotated[str | None, Field(title="Version")] = None + version_data: Annotated[dict[str, Any] | None, Field(title="Version Data")] = None class TerminalTIState(str, Enum): diff --git a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py index 12f86dec36bff..9830771701293 100644 --- a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py @@ -25,7 +25,7 @@ from importlib import import_module from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path -from typing import TYPE_CHECKING, Annotated, BinaryIO, ClassVar, Protocol +from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, ClassVar, Protocol from uuid import UUID import attrs @@ -67,6 +67,7 @@ class _BundleInfoLike(Protocol): name: str version: str | None + version_data: dict[str, Any] | None __all__ = ["CallbackSubprocess", "supervise_callback"] @@ -227,6 +228,7 @@ def _target(): bundle = DagBundlesManager().get_bundle( name=bundle_info.name, version=bundle_info.version, + version_data=bundle_info.version_data, ) bundle.initialize() if (bundle_path := str(bundle.path)) 
not in sys.path: diff --git a/task-sdk/src/airflow/sdk/execution_time/schema/schema.json b/task-sdk/src/airflow/sdk/execution_time/schema/schema.json index 087d149d1af03..ae736e5d6621e 100644 --- a/task-sdk/src/airflow/sdk/execution_time/schema/schema.json +++ b/task-sdk/src/airflow/sdk/execution_time/schema/schema.json @@ -434,6 +434,19 @@ ], "default": null, "title": "Version" + }, + "version_data": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Version Data" } }, "required": [ diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index f3fee689928a0..ac8889daf0286 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -1001,6 +1001,7 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: bundle_instance = DagBundlesManager().get_bundle( name=bundle_info.name, version=bundle_info.version, + version_data=bundle_info.version_data, ) bundle_instance.initialize() _verify_bundle_access(bundle_instance, log) diff --git a/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py index c33299b313c81..2c7a7ce0f6bb1 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py @@ -494,6 +494,7 @@ def test_execute_callback_with_bundle_info_should_pass_correct_parameters( mock_bundle_setup["manager"].get_bundle.assert_called_once_with( name=bundle_info.name, version=bundle_info.version, + version_data=bundle_info.version_data, ) mock_bundle_setup["bundle"].initialize.assert_called_once()