Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2215,6 +2215,8 @@ def add_task(self, task: Operator) -> None:

:param task: the task you want to add
"""
from airflow.utils.task_group import TaskGroupContext

if not self.start_date and not task.start_date:
raise AirflowException("DAG is missing the start_date parameter")
# if the task has no start date, assign it the same as the DAG
Expand All @@ -2233,15 +2235,22 @@ def add_task(self, task: Operator) -> None:
elif task.end_date and self.end_date:
task.end_date = min(task.end_date, self.end_date)

task_id = task.task_id
if not task.task_group:
task_group = TaskGroupContext.get_current_task_group(self)
if task_group:
task_id = task_group.child_id(task_id)
task_group.add(task)

if (
task.task_id in self.task_dict and self.task_dict[task.task_id] is not task
) or task.task_id in self._task_group.used_group_ids:
raise DuplicateTaskIdFound(f"Task id '{task.task_id}' has already been added to the DAG")
task_id in self.task_dict and self.task_dict[task_id] is not task
) or task_id in self._task_group.used_group_ids:
raise DuplicateTaskIdFound(f"Task id '{task_id}' has already been added to the DAG")
else:
self.task_dict[task.task_id] = task
self.task_dict[task_id] = task
task.dag = self
# Add task_id to used_group_ids to prevent group_id and task_id collisions.
self._task_group.used_group_ids.add(task.task_id)
self._task_group.used_group_ids.add(task_id)

self.task_count = len(self.task_dict)

Expand Down
6 changes: 1 addition & 5 deletions airflow/models/taskmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,8 @@ def _set_relatives(
)

if not self.has_dag():
# If this task does not yet have a dag, add it to the same dag as the other task and
# put it in the dag's root TaskGroup.
# If this task does not yet have a dag, add it to the same dag as the other task.
self.dag = dag
self.dag.task_group.add(self)

def add_only_new(obj, item_set: Set[str], item: str) -> None:
"""Adds only new items to item set"""
Expand All @@ -210,9 +208,7 @@ def add_only_new(obj, item_set: Set[str], item: str) -> None:
for task in task_list:
if dag and not task.has_dag():
# If the other task does not yet have a dag, add it to the same dag as this task and
# put it in the dag's root TaskGroup.
Comment thread
blcksrx marked this conversation as resolved.
Outdated
dag.add_task(task)
dag.task_group.add(task)
if upstream:
add_only_new(task, task.downstream_task_ids, self.node_id)
add_only_new(self, self.upstream_task_ids, task.node_id)
Expand Down
14 changes: 14 additions & 0 deletions tests/models/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from airflow.utils.file import list_py_file_paths
from airflow.utils.session import create_session, provide_session
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.task_group import TaskGroup, TaskGroupContext
from airflow.utils.timezone import datetime as datetime_tz
from airflow.utils.types import DagRunType
from airflow.utils.weight_rule import WeightRule
Expand Down Expand Up @@ -1403,6 +1404,19 @@ def test_create_dagrun_job_id_is_set(self):
)
assert dr.creating_job_id == job_id

def test_dag_add_task_sets_default_task_group(self):
    """Check which task group a task ends up in after ``DAG.add_task``.

    Two scenarios are exercised against the same DAG:

    * An operator created without a task group is added to the DAG; the
      group returned by ``TaskGroupContext.get_current_task_group(dag)``
      must then expose it under its task id.
    * An operator created with an explicit ``TaskGroup`` is added to the
      DAG; that group must expose it under its label, and the DAG must
      resolve it through the group-prefixed id.
    """
    dag = DAG(dag_id="test_dag_add_task_sets_default_task_group", start_date=DEFAULT_DATE)

    # Scenario 1: no task group was supplied when building the operator.
    ungrouped_op = EmptyOperator(task_id="task_without_group_id")
    # Capture the group before adding so the assertion inspects the same object.
    current_group = TaskGroupContext.get_current_task_group(dag)
    dag.add_task(ungrouped_op)
    found_in_current = current_group.get_child_by_label("task_without_group_id")
    assert found_in_current == ungrouped_op

    # Scenario 2: the operator is bound to an explicit task group up front.
    explicit_group = TaskGroup(group_id="task_group", dag=dag)
    grouped_op = EmptyOperator(task_id="task_with_task_group", task_group=explicit_group)
    dag.add_task(grouped_op)
    found_in_explicit = explicit_group.get_child_by_label("task_with_task_group")
    assert found_in_explicit == grouped_op
    # The DAG-level lookup uses the "<group_id>.<task_id>" form.
    assert dag.get_task("task_group.task_with_task_group") == grouped_op

@parameterized.expand(
[
(State.QUEUED,),
Expand Down