From 4e069316e57008e3b9841853c58c3644fd03d94c Mon Sep 17 00:00:00 2001 From: vincbeck Date: Mon, 25 Nov 2024 17:06:17 -0500 Subject: [PATCH] AIP-82 Send asset change event when trigger fires --- airflow/jobs/triggerer_job_runner.py | 20 +++-- airflow/models/trigger.py | 49 +++++++++++-- tests/models/test_trigger.py | 105 +++++++++++++++++++++++---- 3 files changed, 144 insertions(+), 30 deletions(-) diff --git a/airflow/jobs/triggerer_job_runner.py b/airflow/jobs/triggerer_job_runner.py index c52a7514346ea..e44c6709d4b51 100644 --- a/airflow/jobs/triggerer_job_runner.py +++ b/airflow/jobs/triggerer_job_runner.py @@ -530,11 +530,15 @@ async def create_triggers(self): while self.to_create: trigger_id, trigger_instance = self.to_create.popleft() if trigger_id not in self.triggers: - ti: TaskInstance = trigger_instance.task_instance + ti: TaskInstance | None = trigger_instance.task_instance + trigger_name = ( + f"{ti.dag_id}/{ti.run_id}/{ti.task_id}/{ti.map_index}/{ti.try_number} (ID {trigger_id})" + if ti + else f"ID {trigger_id}" + ) self.triggers[trigger_id] = { "task": asyncio.create_task(self.run_trigger(trigger_id, trigger_instance)), - "name": f"{ti.dag_id}/{ti.run_id}/{ti.task_id}/{ti.map_index}/{ti.try_number} " - f"(ID {trigger_id})", + "name": trigger_name, "events": 0, } else: @@ -636,13 +640,14 @@ async def run_trigger(self, trigger_id, trigger): name = self.triggers[trigger_id]["name"] self.log.info("trigger %s starting", name) try: - self.set_individual_trigger_logging(trigger) + if trigger.task_instance: + self.set_individual_trigger_logging(trigger) async for event in trigger.run(): self.log.info("Trigger %s fired: %s", self.triggers[trigger_id]["name"], event) self.triggers[trigger_id]["events"] += 1 self.events.append((trigger_id, event)) except asyncio.CancelledError: - if timeout := trigger.task_instance.trigger_timeout: + if timeout := trigger.task_instance and trigger.task_instance.trigger_timeout: timeout = timeout.replace(tzinfo=timezone.utc) if not timeout.tzinfo else timeout if timeout < timezone.utcnow(): self.log.error("Trigger cancelled due to timeout") @@ -696,6 +701,7 @@ def update_triggers(self, requested_trigger_ids: set[int]): cancel_trigger_ids = running_trigger_ids - requested_trigger_ids # Bulk-fetch new trigger records new_triggers = Trigger.bulk_fetch(new_trigger_ids) + triggers_with_assets = Trigger.fetch_trigger_ids_with_asset() # Add in new triggers for new_id in new_trigger_ids: # Check it didn't vanish in the meantime @@ -711,11 +717,11 @@ def update_triggers(self, requested_trigger_ids: set[int]): self.failed_triggers.append((new_id, e)) continue - # If new_trigger_orm.task_instance is None, this means the TaskInstance + # If the trigger is not associated to a task or an asset, this means the TaskInstance # row was updated by either Trigger.submit_event or Trigger.submit_failure # and can happen when a single trigger Job is being run on multiple TriggerRunners # in a High-Availability setup. - if new_trigger_orm.task_instance is None: + if new_trigger_orm.task_instance is None and new_id not in triggers_with_assets: self.log.info( ( "TaskInstance for Trigger ID %s is None. It was likely updated by another trigger job. " diff --git a/airflow/models/trigger.py b/airflow/models/trigger.py index b7b6ba9980d51..a53c5d4a448c9 100644 --- a/airflow/models/trigger.py +++ b/airflow/models/trigger.py @@ -25,6 +25,7 @@ from sqlalchemy.orm import relationship, selectinload from sqlalchemy.sql.functions import coalesce +from airflow.assets.manager import AssetManager from airflow.models.asset import asset_trigger_association_table from airflow.models.base import Base from airflow.models.taskinstance import TaskInstance @@ -158,15 +159,21 @@ def bulk_fetch(cls, ids: Iterable[int], session: Session = NEW_SESSION) -> dict[ ) return {obj.id: obj for obj in session.scalars(stmt)} + @classmethod + @provide_session + def fetch_trigger_ids_with_asset(cls, session: Session = NEW_SESSION) -> set[str]: + """Fetch all the trigger IDs associated with at least one asset.""" + query = select(asset_trigger_association_table.columns.trigger_id) + return {trigger_id for trigger_id in session.scalars(query)} + @classmethod @provide_session def clean_unused(cls, session: Session = NEW_SESSION) -> None: """ - Delete all triggers that have no tasks dependent on them. + Delete all triggers that have no tasks dependent on them and are not associated to an asset. - Triggers have a one-to-many relationship to task instances, so we need - to clean those up first. Afterwards we can drop the triggers not - referenced by anyone. + Triggers have a one-to-many relationship to task instances, so we need to clean those up first. + Afterward we can drop the triggers not referenced by anyone. """ # Update all task instances with trigger IDs that are not DEFERRED to remove them for attempt in run_with_db_retries(): @@ -179,9 +186,10 @@ def clean_unused(cls, session: Session = NEW_SESSION) -> None: .values(trigger_id=None) ) - # Get all triggers that have no task instances depending on them and delete them + # Get all triggers that have no task instances and assets depending on them and delete them ids = ( select(cls.id) + .where(~cls.assets.any()) .join(TaskInstance, cls.id == TaskInstance.trigger_id, isouter=True) .group_by(cls.id) .having(func.count(TaskInstance.trigger_id) == 0) @@ -196,7 +204,13 @@ def clean_unused(cls, session: Session = NEW_SESSION) -> None: @classmethod @provide_session def submit_event(cls, trigger_id, event, session: Session = NEW_SESSION) -> None: - """Take an event from an instance of itself, and trigger all dependent tasks to resume.""" + """ + Fire an event. + + Resume all tasks that were in deferred state. + Send an event to all assets associated to the trigger. + """ + # Resume deferred tasks for task_instance in session.scalars( select(TaskInstance).where( TaskInstance.trigger_id == trigger_id, TaskInstance.state == TaskInstanceState.DEFERRED @@ -204,6 +218,14 @@ def submit_event(cls, trigger_id, event, session: Session = NEW_SESSION) -> None ): event.handle_submit(task_instance=task_instance) + # Send an event to assets + trigger = session.scalars(select(cls).where(cls.id == trigger_id)).one() + for asset in trigger.assets: + AssetManager.register_asset_change( + asset=asset.to_public(), + session=session, + ) + @classmethod @provide_session def submit_failure(cls, trigger_id, exc=None, session: Session = NEW_SESSION) -> None: @@ -239,7 +261,7 @@ def submit_failure(cls, trigger_id, exc=None, session: Session = NEW_SESSION) -> @classmethod @provide_session def ids_for_triggerer(cls, triggerer_id, session: Session = NEW_SESSION) -> list[int]: - """Retrieve a list of triggerer_ids.""" + """Retrieve a list of trigger ids.""" return session.scalars(select(cls.id).where(cls.triggerer_id == triggerer_id)).all() @classmethod @@ -301,4 +323,15 @@ def get_sorted_triggers(cls, capacity: int, alive_triggerer_ids: list[int] | Sel session, skip_locked=True, ) - return session.execute(query).all() + ti_triggers = session.execute(query).all() + + query = with_row_locks( + select(cls.id).where(cls.assets.any()).order_by(cls.created_date).limit(capacity), + session, + skip_locked=True, + ) + asset_triggers = session.execute(query).all() + + # Add triggers associated to assets after triggers associated to tasks + # It prioritizes DAGs over event driven scheduling which is fair + return ti_triggers + asset_triggers diff --git a/tests/models/test_trigger.py b/tests/models/test_trigger.py index 235c65857989b..97c2b10208253 100644 --- a/tests/models/test_trigger.py +++ b/tests/models/test_trigger.py @@ -30,6 +30,7 @@ from airflow.jobs.job import Job from airflow.jobs.triggerer_job_runner import TriggererJobRunner from airflow.models import TaskInstance, Trigger, XCom +from airflow.models.asset import AssetEvent, AssetModel, asset_trigger_association_table from airflow.operators.empty import EmptyOperator from airflow.serialization.serialized_objects import BaseSerialization from airflow.triggers.base import ( @@ -59,48 +60,92 @@ def session(): @pytest.fixture(autouse=True) def clear_db(session): session.query(TaskInstance).delete() + session.query(asset_trigger_association_table).delete() session.query(Trigger).delete() + session.query(AssetModel).delete() + session.query(AssetEvent).delete() session.query(Job).delete() yield session session.query(TaskInstance).delete() + session.query(asset_trigger_association_table).delete() session.query(Trigger).delete() + session.query(AssetModel).delete() + session.query(AssetEvent).delete() session.query(Job).delete() session.commit() +def test_fetch_trigger_ids_with_asset(session): + # Create triggers + trigger1 = Trigger(classpath="airflow.triggers.testing.SuccessTrigger1", kwargs={}) + trigger1.id = 1 + trigger2 = Trigger(classpath="airflow.triggers.testing.SuccessTrigger2", kwargs={}) + trigger2.id = 2 + session.add(trigger1) + session.add(trigger2) + # Create assets + asset = AssetModel("test") + asset.triggers.extend([trigger1]) + session.add(asset) + session.commit() + + results = Trigger.fetch_trigger_ids_with_asset() + assert results == {1} + + def test_clean_unused(session, create_task_instance): """ Tests that unused triggers (those with no task instances referencing them) are cleaned out automatically. """ - # Make three triggers - trigger1 = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={}) + # Create triggers + trigger1 = Trigger(classpath="airflow.triggers.testing.SuccessTrigger1", kwargs={}) trigger1.id = 1 - trigger2 = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={}) + trigger2 = Trigger(classpath="airflow.triggers.testing.SuccessTrigger2", kwargs={}) trigger2.id = 2 - trigger3 = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={}) + trigger3 = Trigger(classpath="airflow.triggers.testing.SuccessTrigger3", kwargs={}) trigger3.id = 3 + trigger4 = Trigger(classpath="airflow.triggers.testing.SuccessTrigger4", kwargs={}) + trigger4.id = 4 + trigger5 = Trigger(classpath="airflow.triggers.testing.SuccessTrigger5", kwargs={}) + trigger5.id = 5 session.add(trigger1) session.add(trigger2) session.add(trigger3) + session.add(trigger4) + session.add(trigger5) session.commit() - assert session.query(Trigger).count() == 3 + assert session.query(Trigger).count() == 5 # Tie one to a fake TaskInstance that is not deferred, and one to one that is task_instance = create_task_instance( session=session, task_id="fake", state=State.DEFERRED, logical_date=timezone.utcnow() ) task_instance.trigger_id = trigger1.id session.add(task_instance) - fake_task = EmptyOperator(task_id="fake2", dag=task_instance.task.dag) - task_instance = TaskInstance(task=fake_task, run_id=task_instance.run_id) - task_instance.state = State.SUCCESS - task_instance.trigger_id = trigger2.id - session.add(task_instance) + fake_task1 = EmptyOperator(task_id="fake2", dag=task_instance.task.dag) + task_instance1 = TaskInstance(task=fake_task1, run_id=task_instance.run_id) + task_instance1.state = State.SUCCESS + task_instance1.trigger_id = trigger2.id + session.add(task_instance1) + fake_task2 = EmptyOperator(task_id="fake3", dag=task_instance.task.dag) + task_instance2 = TaskInstance(task=fake_task2, run_id=task_instance.run_id) + task_instance2.state = State.SUCCESS + task_instance2.trigger_id = trigger4.id + session.add(task_instance2) + session.commit() + + # Create assets + asset = AssetModel("test") + asset.triggers.extend([trigger4, trigger5]) + session.add(asset) session.commit() + assert session.query(AssetModel).count() == 1 + # Run clear operation Trigger.clean_unused() - # Verify that one trigger is gone, and the right one is left - assert session.query(Trigger).one().id == trigger1.id + results = session.query(Trigger).all() + assert len(results) == 3 + assert {result.id for result in results} == {1, 4, 5} def test_submit_event(session, create_task_instance): @@ -120,6 +165,15 @@ def test_submit_event(session, create_task_instance): task_instance.trigger_id = trigger.id task_instance.next_kwargs = {"cheesecake": True} session.commit() + # Create assets + asset = AssetModel("test") + asset.id = 1 + asset.triggers.extend([trigger]) + session.add(asset) + session.commit() + + # Check that the asset has 0 event prior to sending an event to the trigger + assert session.query(AssetEvent).filter_by(asset_id=asset.id).count() == 0 # Call submit_event Trigger.submit_event(trigger.id, TriggerEvent(42), session=session) # commit changes made by submit event and expire all cache to read from db. @@ -128,6 +182,8 @@ def test_submit_event(session, create_task_instance): updated_task_instance = session.query(TaskInstance).one() assert updated_task_instance.state == State.SCHEDULED assert updated_task_instance.next_kwargs == {"event": 42, "cheesecake": True} + # Check that the asset has received an event + assert session.query(AssetEvent).filter_by(asset_id=asset.id).count() == 1 def test_submit_failure(session, create_task_instance): @@ -349,13 +405,32 @@ def test_get_sorted_triggers_same_priority_weight(session, create_task_instance) TI_new.priority_weight = 1 TI_new.trigger_id = trigger_new.id session.add(TI_new) - + trigger_orphan = Trigger( + classpath="airflow.triggers.testing.TriggerOrphan", + kwargs={}, + created_date=new_logical_date, + ) + trigger_orphan.id = 3 + session.add(trigger_orphan) + trigger_asset = Trigger( + classpath="airflow.triggers.testing.TriggerAsset", + kwargs={}, + created_date=new_logical_date, + ) + trigger_asset.id = 4 + session.add(trigger_asset) + session.commit() + assert session.query(Trigger).count() == 4 + # Create assets + asset = AssetModel("test") + asset.id = 1 + asset.triggers.extend([trigger_asset]) + session.add(asset) session.commit() - assert session.query(Trigger).count() == 2 trigger_ids_query = Trigger.get_sorted_triggers(capacity=100, alive_triggerer_ids=[], session=session) - assert trigger_ids_query == [(1,), (2,)] + assert trigger_ids_query == [(1,), (2,), (4,)] def test_get_sorted_triggers_different_priority_weights(session, create_task_instance):