Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 18 additions & 35 deletions providers/celery/tests/integration/celery/test_celery_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,17 @@ def _prepare_app(broker_url=None, execute=None):
set_event_loop(None)


def setup_dagrun_with_success_and_fail_workloads(dag_maker):
date = timezone.utcnow()
start_date = date - timedelta(days=2)

with dag_maker("test_celery_integration"):
BaseOperator(task_id="success", start_date=start_date)
BaseOperator(task_id="fail", start_date=start_date)

return dag_maker.create_dagrun(logical_date=date)


@pytest.mark.integration("celery")
@pytest.mark.backend("postgres")
class TestCeleryExecutor:
Expand All @@ -128,17 +139,6 @@ def teardown_method(self) -> None:
db.clear_db_runs()
db.clear_db_jobs()


def setup_dagrun_with_success_and_fail_workloads(dag_maker):
date = timezone.utcnow()
start_date = date - timedelta(days=2)

with dag_maker("test_celery_integration"):
BaseOperator(task_id="success", start_date=start_date)
BaseOperator(task_id="fail", start_date=start_date)

return dag_maker.create_dagrun(logical_date=date)

@pytest.mark.flaky(reruns=5, reruns_delay=3)
@pytest.mark.parametrize("broker_url", _prepare_test_bodies())
@pytest.mark.parametrize(
Expand Down Expand Up @@ -196,16 +196,11 @@ def fake_execute(input: str) -> None: # Use same parameter name as Airflow 3 ve
# Force single-process sending so mock patches survive (ProcessPoolExecutor
# would fork new processes where the patches are not active).
executor._sync_parallelism = 1
assert executor.tasks == {}
assert executor.workloads == {}
executor.start()

with start_worker(app=app, logfile=sys.stdout, loglevel="info"):
dagrun_date = timezone.utcnow()
dagrun_start = dagrun_date - timedelta(days=2)
with dag_maker("test_celery_integration"):
BaseOperator(task_id="success", start_date=dagrun_start)
BaseOperator(task_id="fail", start_date=dagrun_start)
dagrun = dag_maker.create_dagrun(logical_date=dagrun_date)
dagrun = setup_dagrun_with_success_and_fail_workloads(dag_maker)
ti_fail, ti_success = sorted(dagrun.task_instances, key=lambda ti: ti.task_id)
# Derive keys from the real task instances so they match what the executor tracks
key_fail = TaskInstanceKey(
Expand All @@ -229,31 +224,19 @@ def fake_execute(input: str) -> None: # Use same parameter name as Airflow 3 ve
bundle_info=BundleInfo(name="test"),
log_path="test.log",
)
keys = [
TaskInstanceKey("id", "success", "abc", 0, -1),
TaskInstanceKey("id", "fail", "abc", 0, -1),
]
dagrun = setup_dagrun_with_success_and_fail_workloads(dag_maker)
ti_success, ti_fail = dagrun.task_instances
for w in (
workloads.ExecuteTask.make(
ti=ti_success,
),
workloads.ExecuteTask.make(ti=ti_fail),
):
executor.queue_workload(w, session=None)

executor.trigger_tasks(open_slots=10)
for _ in range(20):
num_tasks = len(executor.tasks.keys())
num_tasks = len(executor.workloads.keys())
if num_tasks == 2:
break
logger.info(
"Waiting 0.1 s for tasks to be processed asynchronously. Processed so far %d",
num_tasks,
)
sleep(0.4)
assert sorted(executor.tasks.keys()) == sorted(keys)
assert sorted(executor.workloads.keys()) == sorted(keys)
assert executor.event_buffer[key_success][0] == State.QUEUED
assert executor.event_buffer[key_fail][0] == State.QUEUED

Expand All @@ -262,8 +245,8 @@ def fake_execute(input: str) -> None: # Use same parameter name as Airflow 3 ve
assert executor.event_buffer[key_success][0] == State.SUCCESS
assert executor.event_buffer[key_fail][0] == State.FAILED

assert key_success not in executor.tasks
assert key_fail not in executor.tasks
assert key_success not in executor.workloads
assert key_fail not in executor.workloads

assert executor.queued_tasks == {}

Expand All @@ -284,7 +267,7 @@ def test_error_sending_workload(self):

key = (task.dag.dag_id, task.task_id, ti.run_id, 0, -1)
executor.queued_tasks[key] = workload
executor.task_publish_retries[key] = 1
executor.workload_publish_retries[key] = 1

# Mock send_workload_to_executor to return an error result.
# This simulates a failure when sending the workload to Celery.
Expand Down
Loading