diff --git a/airflow-core/src/airflow/dag_processing/dagbag.py b/airflow-core/src/airflow/dag_processing/dagbag.py
index 34684810977c7..15095e458d368 100644
--- a/airflow-core/src/airflow/dag_processing/dagbag.py
+++ b/airflow-core/src/airflow/dag_processing/dagbag.py
@@ -337,6 +337,10 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True):
             try:
                 if dag.fileloc is None:
                     dag.fileloc = filepath
+
+                # Add the bundle_name to the DAG
+                dag.bundle_name = self.bundle_name
+
                 # Validate before adding to bag (matches original _process_modules behavior)
                 dag.validate()
                 _validate_executor_fields(dag, self.bundle_name)
diff --git a/airflow-core/src/airflow/models/serialized_dag.py b/airflow-core/src/airflow/models/serialized_dag.py
index ce0333653114e..fd773e8ed4fe3 100644
--- a/airflow-core/src/airflow/models/serialized_dag.py
+++ b/airflow-core/src/airflow/models/serialized_dag.py
@@ -376,6 +376,7 @@ def hash(cls, dag_data):
         # bundle_path and relative fileloc more correctly determines the
         # dag file location.
         data_["dag"].pop("fileloc", None)
+        data_["dag"].pop("bundle_name", None)
         data_json = json.dumps(data_, sort_keys=True).encode("utf-8")
         return md5(data_json).hexdigest()
 
diff --git a/airflow-core/src/airflow/serialization/definitions/dag.py b/airflow-core/src/airflow/serialization/definitions/dag.py
index 61205832ac5f5..abdc25bc4fd8b 100644
--- a/airflow-core/src/airflow/serialization/definitions/dag.py
+++ b/airflow-core/src/airflow/serialization/definitions/dag.py
@@ -116,6 +116,7 @@ class SerializedDAG:
     rerun_with_latest_version: bool | None = None
     doc_md: str | None = None
     edge_info: dict[str, dict[str, EdgeInfoType]] = attrs.field(factory=dict)
+    bundle_name: str | None = None
     end_date: datetime.datetime | None = None
     fail_fast: bool = False
     has_on_failure_callback: bool = False
@@ -181,6 +182,7 @@ def get_serialized_fields(cls) -> frozenset[str]:
                 "task_group",
                 "timetable",
                 "timezone",
+                "bundle_name",
             }
         )
 
diff --git a/airflow-core/src/airflow/serialization/schema.json b/airflow-core/src/airflow/serialization/schema.json
index b9efe88448039..872c3a1331ee3 100644
--- a/airflow-core/src/airflow/serialization/schema.json
+++ b/airflow-core/src/airflow/serialization/schema.json
@@ -185,6 +185,7 @@
         "fail_fast": { "type": "boolean", "default": false },
         "fileloc": { "type" : "string"},
         "relative_fileloc": { "type" : "string"},
+        "bundle_name": { "anyOf": [{ "type": "null" }, { "type": "string" }] },
         "_processor_dags_folder": {
           "anyOf": [
             { "type": "null" },
diff --git a/airflow-core/tests/unit/dag_processing/test_dagbag.py b/airflow-core/tests/unit/dag_processing/test_dagbag.py
index 99abd92a59f73..bd330fbc176dc 100644
--- a/airflow-core/tests/unit/dag_processing/test_dagbag.py
+++ b/airflow-core/tests/unit/dag_processing/test_dagbag.py
@@ -360,6 +360,26 @@ def test_dagbag_with_bundle_name(self, tmp_path):
         dagbag2 = DagBag(dag_folder=os.fspath(tmp_path))
         assert dagbag2.bundle_name is None
 
+    def test_dag_with_bundle_name(self, tmp_path):
+        """Test that bundle_name is attached to each Dag in the DagBag."""
+        # Write a real Dag file so the loops below cannot pass vacuously on an empty bag.
+        dag_file = tmp_path / "bundle_name_dag.py"
+        dag_file.write_text(
+            "from airflow.sdk import DAG\n"
+            "\n"
+            'dag = DAG(dag_id="bundle_name_dag", schedule=None)\n'
+        )
+
+        dagbag = DagBag(dag_folder=os.fspath(tmp_path), bundle_name="test_bundle")
+        assert "bundle_name_dag" in dagbag.dags
+        for dag in dagbag.dags.values():
+            assert dag.bundle_name == "test_bundle"
+
+        dagbag2 = DagBag(dag_folder=os.fspath(tmp_path))
+        assert "bundle_name_dag" in dagbag2.dags
+        for dag in dagbag2.dags.values():
+            assert dag.bundle_name is None
+
     def test_get_existing_dag(self, tmp_path, standard_example_dags_folder):
         """
         Test that we're able to parse some example DAGs and retrieve them
diff --git a/airflow-core/tests/unit/models/test_serialized_dag.py b/airflow-core/tests/unit/models/test_serialized_dag.py
index a1ad63de5d688..03a4159eb2b70 100644
--- a/airflow-core/tests/unit/models/test_serialized_dag.py
+++ b/airflow-core/tests/unit/models/test_serialized_dag.py
@@ -719,6 +719,30 @@ def test_hash_method_removes_fileloc_and_remains_consistent(self):
         assert "fileloc" in test_data["dag"]
         assert test_data["dag"]["fileloc"] == "/different/path/to/dag.py"
 
+    def test_hash_method_removes_bundle_name_and_remains_consistent(self):
+        """Test that the hash method removes bundle_name before hashing."""
+        test_data = {
+            "__version": 1,
+            "dag": {
+                "bundle_name": "bundle_a",
+                "dag_id": "test_dag",
+                "tasks": {
+                    "task1": {"task_id": "task1"},
+                },
+            },
+        }
+
+        hash_with_bundle_name = SDM.hash(test_data)
+
+        test_data["dag"]["bundle_name"] = "bundle_b"
+
+        hash_with_different_bundle_name = SDM.hash(test_data)
+
+        assert hash_with_bundle_name == hash_with_different_bundle_name
+
+        # Verify original data is not mutated by hash()
+        assert test_data["dag"]["bundle_name"] == "bundle_b"
+
     def test_hash_method_consistent_with_dict_ordering_in_template_fields(self, dag_maker):
         from airflow.sdk.bases.operator import BaseOperator
 
diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py
index 5bc5b0ae84768..4699415a04d72 100644
--- a/airflow-core/tests/unit/serialization/test_dag_serialization.py
+++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py
@@ -2454,6 +2454,22 @@ def test_dag_rerun_with_latest_version_roundtrip(self, value, expected):
         deserialized_dag = DagSerialization.from_dict(serialized_dag)
         assert deserialized_dag.rerun_with_latest_version is expected
 
+    @pytest.mark.parametrize(
+        ("bundle_name", "expected"),
+        [
+            ("my_bundle", "my_bundle"),
+            (None, None),
+        ],
+    )
+    def test_dag_bundle_name_roundtrip(self, bundle_name, expected):
+        """Test that bundle_name survives serialization roundtrip."""
+        dag = DAG(dag_id="test_dag_bundle_name_roundtrip", schedule=None)
+        BaseOperator(task_id="simple_task", dag=dag, start_date=datetime(2019, 8, 1))
+        dag.bundle_name = bundle_name
+        serialized_dag = DagSerialization.to_dict(dag)
+        deserialized_dag = DagSerialization.from_dict(serialized_dag)
+        assert deserialized_dag.bundle_name == expected
+
     @pytest.mark.parametrize(
         ("object_to_serialized", "expected_output"),
         [
diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py
index fd47de467fea4..ad69f02dc9c31 100644
--- a/task-sdk/src/airflow/sdk/definitions/dag.py
+++ b/task-sdk/src/airflow/sdk/definitions/dag.py
@@ -542,6 +542,8 @@ def __rich_repr__(self):
 
     fileloc: str = attrs.field(init=False, factory=_default_fileloc)
     relative_fileloc: str | None = attrs.field(init=False, default=None)
+    bundle_name: str | None = attrs.field(init=False, default=None)
+
     partial: bool = attrs.field(init=False, default=False)
 
     edge_info: dict[str, dict[str, EdgeInfoType]] = attrs.field(init=False, factory=dict)
diff --git a/task-sdk/tests/task_sdk/definitions/test_dag.py b/task-sdk/tests/task_sdk/definitions/test_dag.py
index c42eb8dfc80a1..7f3c6f2bf3b2c 100644
--- a/task-sdk/tests/task_sdk/definitions/test_dag.py
+++ b/task-sdk/tests/task_sdk/definitions/test_dag.py
@@ -703,6 +703,15 @@ def noop_pipeline(): ...
         assert dag.dag_id == "noop_pipeline"
         assert dag.fileloc == __file__
 
+    def test_bundle_name_defaults_to_none(self):
+        dag = DAG("test_dag", schedule=None)
+        assert dag.bundle_name is None
+
+    def test_bundle_name_can_be_set(self):
+        dag = DAG("test_dag", schedule=None)
+        dag.bundle_name = "my_bundle"
+        assert dag.bundle_name == "my_bundle"
+
     def test_set_dag_id(self):
         """Test that checks you can set dag_id from decorator."""
 