From 08584b70a5020c289e044b86c83a0277d7948b94 Mon Sep 17 00:00:00 2001 From: Chao-Han Tsai Date: Tue, 14 Jul 2020 15:51:27 -0700 Subject: [PATCH 1/5] Change DAG.clear to take dag_run_state --- airflow/cli/commands/dag_command.py | 3 +- airflow/models/dag.py | 37 ++++++------ .../cloud/example_dags/example_datafusion.py | 3 +- .../google/cloud/example_dags/example_gcs.py | 3 +- .../example_dags/example_campaign_manager.py | 3 +- tests/models/test_dag.py | 59 ++++++++++++++++--- 6 files changed, 78 insertions(+), 30 deletions(-) diff --git a/airflow/cli/commands/dag_command.py b/airflow/cli/commands/dag_command.py index 685ca21c782c6..8a0f162f39094 100644 --- a/airflow/cli/commands/dag_command.py +++ b/airflow/cli/commands/dag_command.py @@ -39,6 +39,7 @@ from airflow.utils.cli import get_dag, get_dag_by_file_location, process_subdir, sigint_handler from airflow.utils.dot_renderer import render_dag from airflow.utils.session import create_session, provide_session +from airflow.utils.state import State def _tabulate_dag_runs(dag_runs: List[DagRun], tablefmt: str = "fancy_grid") -> str: @@ -381,7 +382,7 @@ def dag_list_dag_runs(args, dag=None): def dag_test(args, session=None): """Execute one single DagRun for a given DAG and execution date, using the DebugExecutor.""" dag = get_dag(subdir=args.subdir, dag_id=args.dag_id) - dag.clear(start_date=args.execution_date, end_date=args.execution_date, reset_dag_runs=True) + dag.clear(start_date=args.execution_date, end_date=args.execution_date, dag_run_state=State.NONE) try: dag.run(executor=DebugExecutor(), start_date=args.execution_date, end_date=args.execution_date) except BackfillUnfinished as e: diff --git a/airflow/models/dag.py b/airflow/models/dag.py index e6aafd37b110c..37061219ba932 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -966,7 +966,7 @@ def clear( confirm_prompt=False, include_subdags=True, include_parentdag=True, - reset_dag_runs=True, + dag_run_state: str = State.RUNNING, dry_run=False, session=None, get_tis=False, @@ -993,8 +993,7 @@ def clear( :type include_subdags: bool :param include_parentdag: Clear tasks in the parent dag of the subdag. :type include_parentdag: bool - :param reset_dag_runs: Set state of dag to RUNNING - :type reset_dag_runs: bool + :param dag_run_state: state to set DagRun to :param dry_run: Find the tasks to clear but don't clear them. :type dry_run: bool :param session: The sqlalchemy session to use @@ -1039,7 +1038,7 @@ def clear( confirm_prompt=confirm_prompt, include_subdags=include_subdags, include_parentdag=False, - reset_dag_runs=reset_dag_runs, + dag_run_state=dag_run_state, get_tis=True, session=session, recursion_depth=recursion_depth, @@ -1103,7 +1102,7 @@ def clear( confirm_prompt=confirm_prompt, include_subdags=include_subdags, include_parentdag=False, - reset_dag_runs=reset_dag_runs, + dag_run_state=dag_run_state, get_tis=True, session=session, recursion_depth=recursion_depth + 1, @@ -1134,16 +1133,18 @@ def clear( do_it = utils.helpers.ask_yesno(question) if do_it: - clear_task_instances(tis, - session, - dag=self, - ) - if reset_dag_runs: - self.set_dag_runs_state(session=session, - start_date=start_date, - end_date=end_date, - state=State.NONE, - ) + clear_task_instances( + tis, + session, + dag=self, + activate_dag_runs=False, # We will set DagRun state later. + ) + self.set_dag_runs_state( + session=session, + start_date=start_date, + end_date=end_date, + state=dag_run_state, + ) else: count = 0 print("Bail. Nothing was cleared.") @@ -1161,7 +1162,7 @@ def clear_dags( confirm_prompt=False, include_subdags=True, include_parentdag=False, - reset_dag_runs=True, + dag_run_state=State.RUNNING, dry_run=False, ): all_tis = [] @@ -1174,7 +1175,7 @@ def clear_dags( confirm_prompt=False, include_subdags=include_subdags, include_parentdag=include_parentdag, - reset_dag_runs=reset_dag_runs, + dag_run_state=dag_run_state, dry_run=True) all_tis.extend(tis) @@ -1202,7 +1203,7 @@ def clear_dags( only_running=only_running, confirm_prompt=False, include_subdags=include_subdags, - reset_dag_runs=reset_dag_runs, + dag_run_state=dag_run_state, dry_run=False, ) else: diff --git a/airflow/providers/google/cloud/example_dags/example_datafusion.py b/airflow/providers/google/cloud/example_dags/example_datafusion.py index e2b686cfe6463..62ab1d46fdaf8 100644 --- a/airflow/providers/google/cloud/example_dags/example_datafusion.py +++ b/airflow/providers/google/cloud/example_dags/example_datafusion.py @@ -29,6 +29,7 @@ CloudDataFusionStopPipelineOperator, CloudDataFusionUpdateInstanceOperator, ) from airflow.utils import dates +from airflow.utils.state import State # [START howto_data_fusion_env_variables] LOCATION = "europe-north1" @@ -227,5 +228,5 @@ delete_pipeline >> delete_instance if __name__ == "__main__": - dag.clear(reset_dag_runs=True) + dag.clear(dag_run_state=State.NONE) dag.run() diff --git a/airflow/providers/google/cloud/example_dags/example_gcs.py b/airflow/providers/google/cloud/example_dags/example_gcs.py index 4cdac3636088e..18f173f66edbe 100644 --- a/airflow/providers/google/cloud/example_dags/example_gcs.py +++ b/airflow/providers/google/cloud/example_dags/example_gcs.py @@ -32,6 +32,7 @@ from airflow.providers.google.cloud.transfers.gcs_to_local import GCSToLocalFilesystemOperator from airflow.providers.google.cloud.transfers.local_to_gcs import LocalFilesystemToGCSOperator from airflow.utils.dates import days_ago +from airflow.utils.state import State default_args = {"start_date": days_ago(1)} @@ -155,5 +156,5 @@ if __name__ == '__main__': - dag.clear(reset_dag_runs=True) + dag.clear(dag_run_state=State.NONE) dag.run() diff --git a/airflow/providers/google/marketing_platform/example_dags/example_campaign_manager.py b/airflow/providers/google/marketing_platform/example_dags/example_campaign_manager.py index ca82c93269ae9..74fb6d328210a 100644 --- a/airflow/providers/google/marketing_platform/example_dags/example_campaign_manager.py +++ b/airflow/providers/google/marketing_platform/example_dags/example_campaign_manager.py @@ -31,6 +31,7 @@ GoogleCampaignManagerReportSensor, ) from airflow.utils import dates +from airflow.utils.state import State PROFILE_ID = os.environ.get("MARKETING_PROFILE_ID", "123456789") FLOODLIGHT_ACTIVITY_ID = os.environ.get("FLOODLIGHT_ACTIVITY_ID", 12345) @@ -157,5 +158,5 @@ insert_conversion >> update_conversion if __name__ == "__main__": - dag.clear(reset_dag_runs=True) + dag.clear(dag_run_state=State.NONE) dag.run() diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 8207b2457f0c1..dd38445194300 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -55,6 +55,12 @@ class TestDag(unittest.TestCase): + def setUp(self) -> None: + clear_db_runs() + + def tearDown(self) -> None: + clear_db_runs() + @staticmethod def _clean_up(dag_id: str): with create_session() as session: @@ -1355,8 +1361,14 @@ def test_create_dagrun_run_type_is_obtained_from_run_id(self): dr = dag.create_dagrun(run_id="custom_is_set_to_manual", state=State.NONE) assert dr.run_type == DagRunType.MANUAL.value - def test_clear_reset_dagruns(self): - dag_id = 'test_clear_dag_reset_dagruns' + @parameterized.expand( + [ + (State.NONE,), + (State.RUNNING,), + ] + ) + def test_clear_set_dagrun_state(self, dag_run_state): + dag_id = 'test_clear_set_dagrun_state' self._clean_up(dag_id) task_id = 't1' dag = DAG(dag_id, start_date=DEFAULT_DATE, max_active_runs=1) @@ -1365,7 +1377,7 @@ def test_clear_reset_dagruns(self): session = settings.Session() dagrun_1 = dag.create_dagrun( run_type=DagRunType.BACKFILL_JOB, - state=State.RUNNING, + state=State.FAILED, start_date=DEFAULT_DATE, execution_date=DEFAULT_DATE, ) @@ -1378,7 +1390,7 @@ def test_clear_reset_dagruns(self): dag.clear( start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=1), - reset_dag_runs=True, + dag_run_state=dag_run_state, include_subdags=False, include_parentdag=False, session=session, @@ -1392,17 +1404,48 @@ def test_clear_reset_dagruns(self): self.assertEqual(len(dagruns), 1) dagrun = dagruns[0] # type: DagRun - self.assertEqual(dagrun.state, State.NONE) + self.assertEqual(dagrun.state, dag_run_state) + + @parameterized.expand([ + (state, State.NONE) + for state in State.task_states if state != State.RUNNING + ] + [(State.RUNNING, State.SHUTDOWN)]) + def test_clear_dag(self, ti_state_begin, ti_state_end): + dag_id = 'test_clear_dag' + self._clean_up(dag_id) + task_id = 't1' + dag = DAG(dag_id, start_date=DEFAULT_DATE, max_active_runs=1) + t_1 = DummyOperator(task_id=task_id, dag=dag) + + session = settings.Session() + dagrun_1 = dag.create_dagrun( + run_type=DagRunType.BACKFILL_JOB, + state=State.RUNNING, + start_date=DEFAULT_DATE, + execution_date=DEFAULT_DATE, + ) + session.merge(dagrun_1) + + task_instance_1 = TI(t_1, execution_date=DEFAULT_DATE, state=ti_state_begin) + task_instance_1.job_id = 123 + session.merge(task_instance_1) + session.commit() + + dag.clear( + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=1), + session=session, + ) task_instances = session.query( - DagRun, + TI, ).filter( - DagRun.dag_id == dag_id, + TI.dag_id == dag_id, ).all() self.assertEqual(len(task_instances), 1) task_instance = task_instances[0] # type: TI - self.assertEqual(task_instance.state, State.NONE) + self.assertEqual(task_instance.state, ti_state_end) self._clean_up(dag_id) From 8e65f6c0fd9d6af55a333ac5d85ff64f3a7a5057 Mon Sep 17 00:00:00 2001 From: Chao-Han Tsai Date: Tue, 14 Jul 2020 19:51:41 -0700 Subject: [PATCH 2/5] fix lint --- airflow/cli/commands/dag_command.py | 1 + airflow/models/dag.py | 54 ++++++++++++++--------------- tests/models/test_dag.py | 7 ++-- 3 files changed, 32 insertions(+), 30 deletions(-) diff --git a/airflow/cli/commands/dag_command.py b/airflow/cli/commands/dag_command.py index 8a0f162f39094..55c40b4d812bd 100644 --- a/airflow/cli/commands/dag_command.py +++ b/airflow/cli/commands/dag_command.py @@ -124,6 +124,7 @@ def dag_backfill(args, dag=None): end_date=args.end_date, confirm_prompt=not args.yes, include_subdags=True, + dag_run_state=State.NONE, ) dag.run( diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 37061219ba932..d565b4e041d37 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -27,7 +27,7 @@ import warnings from collections import OrderedDict from datetime import datetime, timedelta -from typing import Callable, Collection, Dict, FrozenSet, Iterable, List, Optional, Set, Type, Union +from typing import Callable, Collection, Dict, FrozenSet, Iterable, List, Optional, Set, Type, Union, cast import jinja2 import pendulum @@ -297,7 +297,7 @@ def __init__( template_searchpath = [template_searchpath] self.template_searchpath = template_searchpath self.template_undefined = template_undefined - self.parent_dag = None # Gets set when DAGs are loaded + self.parent_dag: Optional[DAG] = None # Gets set when DAGs are loaded self.last_loaded = timezone.utcnow() self.safe_dag_id = dag_id.replace('.', '__dot__') self.max_active_runs = max_active_runs @@ -1025,26 +1025,26 @@ def clear( tis = tis.filter(TI.task_id.in_(self.task_ids)) if include_parentdag and self.is_subdag: - - p_dag = self.parent_dag.sub_dag( - task_regex=r"^{}$".format(self.dag_id.split('.')[1]), - include_upstream=False, - include_downstream=True) - - tis = tis.union(p_dag.clear( - start_date=start_date, end_date=end_date, - only_failed=only_failed, - only_running=only_running, - confirm_prompt=confirm_prompt, - include_subdags=include_subdags, - include_parentdag=False, - dag_run_state=dag_run_state, - get_tis=True, - session=session, - recursion_depth=recursion_depth, - max_recursion_depth=max_recursion_depth, - dag_bag=dag_bag - )) + if self.parent_dag is not None: + p_dag = self.parent_dag.sub_dag( + task_regex=r"^{}$".format(self.dag_id.split('.')[1]), + include_upstream=False, + include_downstream=True) + + tis = tis.union(p_dag.clear( + start_date=start_date, end_date=end_date, + only_failed=only_failed, + only_running=only_running, + confirm_prompt=confirm_prompt, + include_subdags=include_subdags, + include_parentdag=False, + dag_run_state=dag_run_state, + get_tis=True, + session=session, + recursion_depth=recursion_depth, + max_recursion_depth=max_recursion_depth, + dag_bag=dag_bag + )) if start_date: tis = tis.filter(TI.execution_date >= start_date) @@ -1064,12 +1064,12 @@ def clear( instances = tis.all() for ti in instances: if ti.operator == ExternalTaskMarker.__name__: - ti.task = self.get_task(ti.task_id) + task: ExternalTaskMarker = cast(ExternalTaskMarker, self.get_task(ti.task_id)) if recursion_depth == 0: # Maximum recursion depth allowed is the recursion_depth of the first # ExternalTaskMarker in the tasks to be cleared. - max_recursion_depth = ti.task.recursion_depth + max_recursion_depth = task.recursion_depth if recursion_depth + 1 > max_recursion_depth: # Prevent cycles or accidents. @@ -1079,10 +1079,10 @@ def clear( .format(max_recursion_depth, ExternalTaskMarker.__name__, ti.task_id)) ti.render_templates() - external_tis = session.query(TI).filter(TI.dag_id == ti.task.external_dag_id, - TI.task_id == ti.task.external_task_id, + external_tis = session.query(TI).filter(TI.dag_id == task.external_dag_id, + TI.task_id == task.external_task_id, TI.execution_date == - pendulum.parse(ti.task.execution_date)) + pendulum.parse(task.execution_date)) for tii in external_tis: if not dag_bag: diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index dd38445194300..8891d56c8e98d 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -25,6 +25,7 @@ import unittest from contextlib import redirect_stdout from tempfile import NamedTemporaryFile +from typing import Optional from unittest import mock from unittest.mock import patch @@ -1409,15 +1410,15 @@ def test_clear_set_dagrun_state(self, dag_run_state): @parameterized.expand([ (state, State.NONE) for state in State.task_states if state != State.RUNNING - ] + [(State.RUNNING, State.SHUTDOWN)]) - def test_clear_dag(self, ti_state_begin, ti_state_end): + ] + [(State.RUNNING, State.SHUTDOWN)]) # type: ignore + def test_clear_dag(self, ti_state_begin, ti_state_end: Optional[str]): dag_id = 'test_clear_dag' self._clean_up(dag_id) task_id = 't1' dag = DAG(dag_id, start_date=DEFAULT_DATE, max_active_runs=1) t_1 = DummyOperator(task_id=task_id, dag=dag) - session = settings.Session() + session = settings.Session() # type: ignore dagrun_1 = dag.create_dagrun( run_type=DagRunType.BACKFILL_JOB, state=State.RUNNING, From 33f0a4bdf620bacd1973c7346211785a96e7183a Mon Sep 17 00:00:00 2001 From: Chao-Han Tsai Date: Tue, 14 Jul 2020 21:19:19 -0700 Subject: [PATCH 3/5] fix tests --- airflow/models/dag.py | 1 + tests/cli/commands/test_dag_command.py | 7 +++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index d565b4e041d37..d2e907cd9457d 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -1065,6 +1065,7 @@ def clear( for ti in instances: if ti.operator == ExternalTaskMarker.__name__: task: ExternalTaskMarker = cast(ExternalTaskMarker, self.get_task(ti.task_id)) + ti.task = self.get_task(ti.task_id) if recursion_depth == 0: # Maximum recursion depth allowed is the recursion_depth of the first diff --git a/tests/cli/commands/test_dag_command.py b/tests/cli/commands/test_dag_command.py index fa4a32e50ccbc..6dda923931ae6 100644 --- a/tests/cli/commands/test_dag_command.py +++ b/tests/cli/commands/test_dag_command.py @@ -433,7 +433,8 @@ def test_dag_test(self, mock_get_dag, mock_executor): subdir=cli_args.subdir, dag_id='example_bash_operator' ), mock.call().clear( - start_date=cli_args.execution_date, end_date=cli_args.execution_date, reset_dag_runs=True + start_date=cli_args.execution_date, end_date=cli_args.execution_date, + dag_run_state=State.NONE, ), mock.call().run( executor=mock_executor.return_value, @@ -461,7 +462,9 @@ def test_dag_test_show_dag(self, mock_get_dag, mock_executor, mock_render_dag): subdir=cli_args.subdir, dag_id='example_bash_operator' ), mock.call().clear( - start_date=cli_args.execution_date, end_date=cli_args.execution_date, reset_dag_runs=True + start_date=cli_args.execution_date, + end_date=cli_args.execution_date, + dag_run_state=State.NONE, ), mock.call().run( executor=mock_executor.return_value, From 6c40027d9efb2353d169a491e6bc9715f1d967b5 Mon Sep 17 00:00:00 2001 From: Chao-Han Tsai Date: Tue, 14 Jul 2020 21:49:19 -0700 Subject: [PATCH 4/5] assign var --- airflow/models/dag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index d2e907cd9457d..63994e8447013 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -1065,7 +1065,7 @@ def clear( for ti in instances: if ti.operator == ExternalTaskMarker.__name__: task: ExternalTaskMarker = cast(ExternalTaskMarker, self.get_task(ti.task_id)) - ti.task = self.get_task(ti.task_id) + ti.task = task if recursion_depth == 0: # Maximum recursion depth allowed is the recursion_depth of the first From 8e6c807bf36df55c712dd9f48aed28ebffba6340 Mon Sep 17 00:00:00 2001 From: Chao-Han Tsai Date: Wed, 15 Jul 2020 10:26:55 -0700 Subject: [PATCH 5/5] extend original clause --- airflow/models/dag.py | 41 ++++++++++++++++++++--------------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 63994e8447013..dfb6409c69003 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -1024,27 +1024,26 @@ def clear( tis = session.query(TI).filter(TI.dag_id == self.dag_id) tis = tis.filter(TI.task_id.in_(self.task_ids)) - if include_parentdag and self.is_subdag: - if self.parent_dag is not None: - p_dag = self.parent_dag.sub_dag( - task_regex=r"^{}$".format(self.dag_id.split('.')[1]), - include_upstream=False, - include_downstream=True) - - tis = tis.union(p_dag.clear( - start_date=start_date, end_date=end_date, - only_failed=only_failed, - only_running=only_running, - confirm_prompt=confirm_prompt, - include_subdags=include_subdags, - include_parentdag=False, - dag_run_state=dag_run_state, - get_tis=True, - session=session, - recursion_depth=recursion_depth, - max_recursion_depth=max_recursion_depth, - dag_bag=dag_bag - )) + if include_parentdag and self.is_subdag and self.parent_dag is not None: + p_dag = self.parent_dag.sub_dag( + task_regex=r"^{}$".format(self.dag_id.split('.')[1]), + include_upstream=False, + include_downstream=True) + + tis = tis.union(p_dag.clear( + start_date=start_date, end_date=end_date, + only_failed=only_failed, + only_running=only_running, + confirm_prompt=confirm_prompt, + include_subdags=include_subdags, + include_parentdag=False, + dag_run_state=dag_run_state, + get_tis=True, + session=session, + recursion_depth=recursion_depth, + max_recursion_depth=max_recursion_depth, + dag_bag=dag_bag + )) if start_date: tis = tis.filter(TI.execution_date >= start_date)