From ffb40abe790db67aace1c17044a4e279beeef2aa Mon Sep 17 00:00:00 2001 From: Hossein Torabi Date: Tue, 12 Jul 2022 15:28:12 +0200 Subject: [PATCH] set default task group in dag.add_task method Signed-off-by: Hossein Torabi --- airflow/models/dag.py | 19 ++++++++++++++----- airflow/models/taskmixin.py | 6 +----- tests/models/test_dag.py | 14 ++++++++++++++ 3 files changed, 29 insertions(+), 10 deletions(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 596f231721ec5..a85c98d96fc8e 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -2215,6 +2215,8 @@ def add_task(self, task: Operator) -> None: :param task: the task you want to add """ + from airflow.utils.task_group import TaskGroupContext + if not self.start_date and not task.start_date: raise AirflowException("DAG is missing the start_date parameter") # if the task has no start date, assign it the same as the DAG @@ -2233,15 +2235,22 @@ def add_task(self, task: Operator) -> None: elif task.end_date and self.end_date: task.end_date = min(task.end_date, self.end_date) + task_id = task.task_id + if not task.task_group: + task_group = TaskGroupContext.get_current_task_group(self) + if task_group: + task_id = task_group.child_id(task_id) + task_group.add(task) + if ( - task.task_id in self.task_dict and self.task_dict[task.task_id] is not task - ) or task.task_id in self._task_group.used_group_ids: - raise DuplicateTaskIdFound(f"Task id '{task.task_id}' has already been added to the DAG") + task_id in self.task_dict and self.task_dict[task_id] is not task + ) or task_id in self._task_group.used_group_ids: + raise DuplicateTaskIdFound(f"Task id '{task_id}' has already been added to the DAG") else: - self.task_dict[task.task_id] = task + self.task_dict[task_id] = task task.dag = self # Add task_id to used_group_ids to prevent group_id and task_id collisions. - self._task_group.used_group_ids.add(task.task_id) + self._task_group.used_group_ids.add(task_id) self.task_count = len(self.task_dict) diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py index 06494946a8f5c..7a70e328d2c6c 100644 --- a/airflow/models/taskmixin.py +++ b/airflow/models/taskmixin.py @@ -195,10 +195,8 @@ def _set_relatives( ) if not self.has_dag(): - # If this task does not yet have a dag, add it to the same dag as the other task and - # put it in the dag's root TaskGroup. + # If this task does not yet have a dag, add it to the same dag as the other task. self.dag = dag - self.dag.task_group.add(self) def add_only_new(obj, item_set: Set[str], item: str) -> None: """Adds only new items to item set""" @@ -210,9 +208,7 @@ def add_only_new(obj, item_set: Set[str], item: str) -> None: for task in task_list: if dag and not task.has_dag(): # If the other task does not yet have a dag, add it to the same dag as this task and - # put it in the dag's root TaskGroup. dag.add_task(task) - dag.task_group.add(task) if upstream: add_only_new(task, task.downstream_task_ids, self.node_id) add_only_new(self, self.upstream_task_ids, task.node_id) diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 458c8af73e123..b88ec6284cfda 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -61,6 +61,7 @@ from airflow.utils.file import list_py_file_paths from airflow.utils.session import create_session, provide_session from airflow.utils.state import DagRunState, State, TaskInstanceState +from airflow.utils.task_group import TaskGroup, TaskGroupContext from airflow.utils.timezone import datetime as datetime_tz from airflow.utils.types import DagRunType from airflow.utils.weight_rule import WeightRule @@ -1403,6 +1404,19 @@ def test_create_dagrun_job_id_is_set(self): ) assert dr.creating_job_id == job_id + def test_dag_add_task_sets_default_task_group(self): + dag = DAG(dag_id="test_dag_add_task_sets_default_task_group", start_date=DEFAULT_DATE) + task_without_task_group = EmptyOperator(task_id="task_without_group_id") + default_task_group = TaskGroupContext.get_current_task_group(dag) + dag.add_task(task_without_task_group) + assert default_task_group.get_child_by_label("task_without_group_id") == task_without_task_group + + task_group = TaskGroup(group_id="task_group", dag=dag) + task_with_task_group = EmptyOperator(task_id="task_with_task_group", task_group=task_group) + dag.add_task(task_with_task_group) + assert task_group.get_child_by_label("task_with_task_group") == task_with_task_group + assert dag.get_task("task_group.task_with_task_group") == task_with_task_group + @parameterized.expand( [ (State.QUEUED,),