Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions airflow/jobs/triggerer_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,10 +380,16 @@ def update_triggers(self, requested_trigger_ids: Set[int]):
# line's execution, but we consider that safe, since there's a strict
# add -> remove -> never again lifecycle this function is already
# handling.
current_trigger_ids = set(self.triggers.keys())
running_trigger_ids = set(self.triggers.keys())
known_trigger_ids = (
running_trigger_ids.union(x[0] for x in self.events)
.union(self.to_cancel)
.union(x[0] for x in self.to_create)
.union(self.failed_triggers)
)
# Work out the two difference sets
new_trigger_ids = requested_trigger_ids.difference(current_trigger_ids)
cancel_trigger_ids = current_trigger_ids.difference(requested_trigger_ids)
new_trigger_ids = requested_trigger_ids - known_trigger_ids
cancel_trigger_ids = running_trigger_ids - requested_trigger_ids
# Bulk-fetch new trigger records
new_triggers = Trigger.bulk_fetch(new_trigger_ids)
# Add in new triggers
Expand Down
136 changes: 133 additions & 3 deletions tests/jobs/test_triggerer_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,53 @@
# specific language governing permissions and limitations
# under the License.

import asyncio
import datetime
import sys
import time
from threading import Thread

import pytest

from airflow.jobs.triggerer_job import TriggererJob
from airflow.models import Trigger
from airflow.jobs.triggerer_job import TriggererJob, TriggerRunner
from airflow.models import DagModel, DagRun, TaskInstance, Trigger
from airflow.operators.dummy import DummyOperator
from airflow.operators.python import PythonOperator
from airflow.triggers.base import TriggerEvent
from airflow.triggers.temporal import TimeDeltaTrigger
from airflow.triggers.testing import FailureTrigger, SuccessTrigger
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.state import State, TaskInstanceState
from tests.test_utils.db import clear_db_runs
from tests.test_utils.db import clear_db_dags, clear_db_runs


class TimeDeltaTrigger_(TimeDeltaTrigger):
def __init__(self, delta, filename):
super().__init__(delta=delta)
self.filename = filename
self.delta = delta

async def run(self):
with open(self.filename, 'at') as f:
f.write('hi\n')
async for event in super().run():
yield event

def serialize(self):
return (
"tests.jobs.test_triggerer_job.TimeDeltaTrigger_",
{"delta": self.delta, "filename": self.filename},
)


@pytest.fixture(autouse=True)
def clean_database():
"""Fixture that cleans the database before and after every test."""
clear_db_runs()
clear_db_dags()
yield # Test runs here
clear_db_dags()
clear_db_runs()


Expand Down Expand Up @@ -159,6 +183,112 @@ def test_trigger_lifecycle(session):
job.runner.stop = True


@pytest.mark.skipif(sys.version_info.minor <= 6 and sys.version_info.major <= 3, reason="No triggerer on 3.6")
def test_trigger_create_race_condition_18392(session, tmp_path):
"""
This verifies the resolution of race condition documented in github issue #18392.
Triggers are queued for creation by TriggerJob.load_triggers.
There was a race condition where multiple triggers would be created unnecessarily.
What happens is the runner completes the trigger and purges from the "running" list.
Then job.load_triggers is called and it looks like the trigger is not running but should,
so it queues it again.

The scenario is as follows:
1. job.load_triggers (trigger now queued)
2. runner.create_triggers (trigger now running)
3. job.handle_events (trigger still appears running so state not updated in DB)
4. runner.cleanup_finished_triggers (trigger completed at this point; trigger from "running" set)
5. job.load_triggers (trigger not running, but also not purged from DB, so it is queued again)
6. runner.create_triggers (trigger created again)

This test verifies that under this scenario only one trigger is created.
"""
path = tmp_path / 'test_trigger_bad_respawn.txt'

class TriggerRunner_(TriggerRunner):
"""We do some waiting for main thread looping"""

async def wait_for_job_method_count(self, method, count):
for _ in range(30):
await asyncio.sleep(0.1)
if getattr(self, f'{method}_count', 0) >= count:
break
else:
pytest.fail(f"did not observe count {count} in job method {method}")

async def create_triggers(self):
"""
On first run, wait for job.load_triggers to make sure they are queued
"""
if getattr(self, 'loop_count', 0) == 0:
await self.wait_for_job_method_count('load_triggers', 1)
await super().create_triggers()
self.loop_count = getattr(self, 'loop_count', 0) + 1

async def cleanup_finished_triggers(self):
"""On loop 1, make sure that job.handle_events was already called"""
if self.loop_count == 1:
await self.wait_for_job_method_count('handle_events', 1)
await super().cleanup_finished_triggers()

class TriggererJob_(TriggererJob):
"""We do some waiting for runner thread looping (and track calls in job thread)"""

def wait_for_runner_loop(self, runner_loop_count):
for _ in range(30):
time.sleep(0.1)
if getattr(self.runner, 'call_count', 0) >= runner_loop_count:
break
else:
pytest.fail("did not observe 2 loops in the runner thread")

def load_triggers(self):
"""On second run, make sure that runner has called create_triggers in its second loop"""
super().load_triggers()
self.runner.load_triggers_count = getattr(self.runner, 'load_triggers_count', 0) + 1
if self.runner.load_triggers_count == 2:
self.wait_for_runner_loop(runner_loop_count=2)

def handle_events(self):
super().handle_events()
self.runner.handle_events_count = getattr(self.runner, 'handle_events_count', 0) + 1

trigger = TimeDeltaTrigger_(delta=datetime.timedelta(microseconds=1), filename=path.as_posix())
trigger_orm = Trigger.from_object(trigger)
trigger_orm.id = 1
session.add(trigger_orm)

dag = DagModel(dag_id='test-dag')
dag_run = DagRun(dag.dag_id, run_id='abc', run_type='none')
ti = TaskInstance(PythonOperator(task_id='dummy-task', python_callable=print), run_id=dag_run.run_id)
ti.dag_id = dag.dag_id
ti.trigger_id = 1
session.add(dag)
session.add(dag_run)
session.add(ti)

session.commit()

job = TriggererJob_()
job.runner = TriggerRunner_()
thread = Thread(target=job._execute)
thread.start()
try:
for _ in range(40):
time.sleep(0.1)
# ready to evaluate after 2 loops
if getattr(job.runner, 'loop_count', 0) >= 2:
break
else:
pytest.fail("did not observe 2 loops in the runner thread")
finally:
job.runner.stop = True
job.runner.join()
thread.join()
instances = path.read_text().splitlines()
assert len(instances) == 1


@pytest.mark.skipif(sys.version_info.minor <= 6 and sys.version_info.major <= 3, reason="No triggerer on 3.6")
def test_trigger_from_dead_triggerer(session):
"""
Expand Down