diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py b/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py index f72dfc52fbb67..9d6227a695cc2 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py @@ -611,7 +611,7 @@ def handle_bulk_delete( try: # Handle deletion of specific (dag_id, dag_run_id, task_id, map_index) tuples if delete_specific_map_index_task_keys: - _, matched_task_keys, not_found_task_keys = self._categorize_task_instances( + task_instances_map, matched_task_keys, not_found_task_keys = self._categorize_task_instances( delete_specific_map_index_task_keys ) not_found_task_ids = [ @@ -625,23 +625,10 @@ def handle_bulk_delete( detail=f"The task instances with these identifiers: {not_found_task_ids} were not found", ) - for dag_id, run_id, task_id, map_index in matched_task_keys: - ti = ( - self.session.execute( - select(TI).where( - TI.dag_id == dag_id, - TI.run_id == run_id, - TI.task_id == task_id, - TI.map_index == map_index, - ) - ) - .scalars() - .one_or_none() - ) - - if ti: - self.session.delete(ti) - results.success.append(f"{dag_id}.{run_id}.{task_id}[{map_index}]") + for task_key in matched_task_keys: + dag_id, run_id, task_id, map_index = task_key + self.session.delete(task_instances_map[task_key]) + results.success.append(f"{dag_id}.{run_id}.{task_id}[{map_index}]") # Handle deletion of all map indexes for certain (dag_id, dag_run_id, task_id) tuples if delete_all_map_index_task_keys: diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py index d49e36c8c63ba..1e874f925b30d 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py @@ -6719,6 +6719,46 @@ def test_bulk_delete_rejects_unauthorized_dag_ids_from_request_body(self, test_c } ] + @pytest.mark.parametrize("task_count", [5, 10, 20]) + def test_bulk_delete_query_count_scales_linearly_with_task_count(self, test_client, session, task_count): + # Regression guard for the N+1 fix in BulkTaskInstanceService.handle_bulk_delete: + # each extra task instance must add exactly QUERIES_PER_TASK_INSTANCE query (its DELETE), + # not 2 (DELETE + re-SELECT). A regression that re-queries inside the loop would make + # each run strictly exceed BASE_QUERY_COUNT + task_count * QUERIES_PER_TASK_INSTANCE. + QUERIES_PER_TASK_INSTANCE = 1 + BASE_QUERY_COUNT = 5 + + self.create_task_instances( + session, + task_instances=[{"state": State.RUNNING, "map_indexes": tuple(range(task_count))}], + ) + request_body = { + "actions": [ + { + "action": "delete", + "entities": [ + {"task_id": self.TASK_ID, "map_index": map_index} for map_index in range(task_count) + ], + "action_on_non_existence": "fail", + } + ] + } + + with count_queries() as result: + response = test_client.patch(self.ENDPOINT_URL, json=request_body) + + assert response.status_code == 200 + assert len(response.json()["delete"]["success"]) == task_count + + query_count = sum(result.values()) + expected_query_count = BASE_QUERY_COUNT + task_count * QUERIES_PER_TASK_INSTANCE + assert query_count == expected_query_count, ( + f"Bulk-delete query count {query_count} does not match expected {expected_query_count} " + f"for {task_count} task instances. " + f"A regression that re-queries each task instance would give " + f"~{BASE_QUERY_COUNT + task_count * 2} queries instead." + ) + def test_should_respond_401(self, unauthenticated_test_client): response = unauthenticated_test_client.patch(self.ENDPOINT_URL, json={}) assert response.status_code == 401