Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions airflow-core/src/airflow/dag_processing/dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,10 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True):
try:
if dag.fileloc is None:
dag.fileloc = filepath

# Add the bundle_name to the DAG

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Add the bundle_name to the DAG
# Add the bundle_name to the Dag

dag.bundle_name = self.bundle_name

# Validate before adding to bag (matches original _process_modules behavior)
dag.validate()
_validate_executor_fields(dag, self.bundle_name)
Expand Down
1 change: 1 addition & 0 deletions airflow-core/src/airflow/models/serialized_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ def hash(cls, dag_data):
# bundle_path and relative fileloc more correctly determines the
# dag file location.
data_["dag"].pop("fileloc", None)
data_["dag"].pop("bundle_name", None)
data_json = json.dumps(data_, sort_keys=True).encode("utf-8")
return md5(data_json).hexdigest()

Expand Down
2 changes: 2 additions & 0 deletions airflow-core/src/airflow/serialization/definitions/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class SerializedDAG:
rerun_with_latest_version: bool | None = None
doc_md: str | None = None
edge_info: dict[str, dict[str, EdgeInfoType]] = attrs.field(factory=dict)
bundle_name: str | None = None
end_date: datetime.datetime | None = None
fail_fast: bool = False
has_on_failure_callback: bool = False
Expand Down Expand Up @@ -181,6 +182,7 @@ def get_serialized_fields(cls) -> frozenset[str]:
"task_group",
"timetable",
"timezone",
"bundle_name",
}
)

Expand Down
1 change: 1 addition & 0 deletions airflow-core/src/airflow/serialization/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@
"fail_fast": { "type": "boolean", "default": false },
"fileloc": { "type" : "string"},
"relative_fileloc": { "type" : "string"},
"bundle_name": { "anyOf": [{ "type": "null" }, { "type": "string" }] },
"_processor_dags_folder": {
"anyOf": [
{ "type": "null" },
Expand Down
10 changes: 10 additions & 0 deletions airflow-core/tests/unit/dag_processing/test_dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,16 @@ def test_dagbag_with_bundle_name(self, tmp_path):
dagbag2 = DagBag(dag_folder=os.fspath(tmp_path))
assert dagbag2.bundle_name is None

def test_dag_with_bundle_name(self, tmp_path):
"""Test that bundle_name is attached to each Dag in the DagBag."""
dagbag = DagBag(dag_folder=os.fspath(tmp_path), bundle_name="test_bundle")
for dag in dagbag.dags.values():
assert dag.bundle_name == "test_bundle"

dagbag2 = DagBag(dag_folder=os.fspath(tmp_path))
for dag in dagbag2.dags.values():
assert dag.bundle_name is None

Comment on lines +363 to +372

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should use a real path

def test_get_existing_dag(self, tmp_path, standard_example_dags_folder):
"""
Test that we're able to parse some example DAGs and retrieve them
Expand Down
24 changes: 24 additions & 0 deletions airflow-core/tests/unit/models/test_serialized_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,30 @@ def test_hash_method_removes_fileloc_and_remains_consistent(self):
assert "fileloc" in test_data["dag"]
assert test_data["dag"]["fileloc"] == "/different/path/to/dag.py"

def test_hash_method_removes_bundle_name_and_remains_consistent(self):
"""Test that the hash method removes bundle_name before hashing."""
test_data = {
"__version": 1,
"dag": {
"bundle_name": "bundle_a",
"dag_id": "test_dag",
"tasks": {
"task1": {"task_id": "task1"},
},
},
}

hash_with_bundle_name = SDM.hash(test_data)

test_data["dag"]["bundle_name"] = "bundle_b"

hash_with_different_bundle_name = SDM.hash(test_data)

assert hash_with_bundle_name == hash_with_different_bundle_name

# Verify original data is not mutated by hash()
assert test_data["dag"]["bundle_name"] == "bundle_b"

def test_hash_method_consistent_with_dict_ordering_in_template_fields(self, dag_maker):
from airflow.sdk.bases.operator import BaseOperator

Expand Down
16 changes: 16 additions & 0 deletions airflow-core/tests/unit/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2454,6 +2454,22 @@ def test_dag_rerun_with_latest_version_roundtrip(self, value, expected):
deserialized_dag = DagSerialization.from_dict(serialized_dag)
assert deserialized_dag.rerun_with_latest_version is expected

@pytest.mark.parametrize(
("bundle_name", "expected"),
[
("my_bundle", "my_bundle"),
(None, None),
],
)
def test_dag_bundle_name_roundtrip(self, bundle_name, expected):
"""Test that bundle_name survives serialization roundtrip."""
dag = DAG(dag_id="test_dag_bundle_name_roundtrip", schedule=None)
BaseOperator(task_id="simple_task", dag=dag, start_date=datetime(2019, 8, 1))
dag.bundle_name = bundle_name
serialized_dag = DagSerialization.to_dict(dag)
deserialized_dag = DagSerialization.from_dict(serialized_dag)
assert deserialized_dag.bundle_name == expected

@pytest.mark.parametrize(
("object_to_serialized", "expected_output"),
[
Expand Down
2 changes: 2 additions & 0 deletions task-sdk/src/airflow/sdk/definitions/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,8 @@ def __rich_repr__(self):

fileloc: str = attrs.field(init=False, factory=_default_fileloc)
relative_fileloc: str | None = attrs.field(init=False, default=None)
bundle_name: str | None = attrs.field(init=False, default=None)

partial: bool = attrs.field(init=False, default=False)

edge_info: dict[str, dict[str, EdgeInfoType]] = attrs.field(init=False, factory=dict)
Expand Down
9 changes: 9 additions & 0 deletions task-sdk/tests/task_sdk/definitions/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,15 @@ def noop_pipeline(): ...
assert dag.dag_id == "noop_pipeline"
assert dag.fileloc == __file__

def test_bundle_name_defaults_to_none(self):
dag = DAG("test_dag", schedule=None)
assert dag.bundle_name is None

def test_bundle_name_can_be_set(self):
dag = DAG("test_dag", schedule=None)
dag.bundle_name = "my_bundle"
assert dag.bundle_name == "my_bundle"

def test_set_dag_id(self):
"""Test that checks you can set dag_id from decorator."""

Expand Down
Loading