Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions airflow/executors/celery_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance
return tis

states_by_celery_task_id = self.bulk_state_fetcher.get_many(
map(operator.itemgetter(0), celery_tasks.values())
list(map(operator.itemgetter(0), celery_tasks.values()))
)

adopted = []
Expand Down Expand Up @@ -526,10 +526,6 @@ def fetch_celery_task_state(async_result: AsyncResult) -> Tuple[str, Union[str,
return async_result.task_id, ExceptionWithTraceback(e, exception_traceback), None


def _tasks_list_to_task_ids(async_tasks) -> Set[str]:
return {a.task_id for a in async_tasks}


class BulkStateFetcher(LoggingMixin):
"""
Gets status for many Celery tasks using the best method available
Expand All @@ -543,20 +539,22 @@ def __init__(self, sync_parallelism=None):
super().__init__()
self._sync_parallelism = sync_parallelism

def _tasks_list_to_task_ids(self, async_tasks) -> Set[str]:
return {a.task_id for a in async_tasks}

def get_many(self, async_results) -> Mapping[str, EventBufferValueType]:
"""Gets status for many Celery tasks using the best method available."""
if isinstance(app.backend, BaseKeyValueStoreBackend):
result = self._get_many_from_kv_backend(async_results)
return result
if isinstance(app.backend, DatabaseBackend):
elif isinstance(app.backend, DatabaseBackend):
result = self._get_many_from_db_backend(async_results)
return result
result = self._get_many_using_multiprocessing(async_results)
self.log.debug("Fetched %d states for %d task", len(result), len(async_results))
else:
result = self._get_many_using_multiprocessing(async_results)
self.log.debug("Fetched %d state(s) for %d task(s)", len(result), len(async_results))
return result

def _get_many_from_kv_backend(self, async_tasks) -> Mapping[str, EventBufferValueType]:
task_ids = _tasks_list_to_task_ids(async_tasks)
task_ids = self._tasks_list_to_task_ids(async_tasks)
keys = [app.backend.get_key_for_task(k) for k in task_ids]
values = app.backend.mget(keys)
task_results = [app.backend.decode_result(v) for v in values if v]
Expand All @@ -565,7 +563,7 @@ def _get_many_from_kv_backend(self, async_tasks) -> Mapping[str, EventBufferValu
return self._prepare_state_and_info_by_task_dict(task_ids, task_results_by_task_id)

def _get_many_from_db_backend(self, async_tasks) -> Mapping[str, EventBufferValueType]:
task_ids = _tasks_list_to_task_ids(async_tasks)
task_ids = self._tasks_list_to_task_ids(async_tasks)
session = app.backend.ResultSession()
task_cls = getattr(app.backend, "task_cls", TaskDb)
with session_cleanup(session):
Expand Down
35 changes: 25 additions & 10 deletions tests/executors/test_celery_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,9 @@ class TestBulkStateFetcher(unittest.TestCase):
def test_should_support_kv_backend(self, mock_mget):
with _prepare_app():
mock_backend = BaseKeyValueStoreBackend(app=celery_executor.app)
with mock.patch.object(celery_executor.app, 'backend', mock_backend):
with mock.patch.object(celery_executor.app, 'backend', mock_backend), self.assertLogs(
"airflow.executors.celery_executor.BulkStateFetcher", level="DEBUG"
) as cm:
fetcher = BulkStateFetcher()
result = fetcher.get_many(
[
Expand All @@ -427,6 +429,9 @@ def test_should_support_kv_backend(self, mock_mget):
mock_mget.assert_called_once_with(mock.ANY)

assert result == {'123': ('SUCCESS', None), '456': ("PENDING", None)}
assert [
'DEBUG:airflow.executors.celery_executor.BulkStateFetcher:Fetched 2 state(s) for 2 task(s)'
] == cm.output

@mock.patch("celery.backends.database.DatabaseBackend.ResultSession")
@pytest.mark.integration("redis")
Expand All @@ -436,21 +441,26 @@ def test_should_support_db_backend(self, mock_session):
with _prepare_app():
mock_backend = DatabaseBackend(app=celery_executor.app, url="sqlite3://")

with mock.patch.object(celery_executor.app, 'backend', mock_backend):
with mock.patch.object(celery_executor.app, 'backend', mock_backend), self.assertLogs(
"airflow.executors.celery_executor.BulkStateFetcher", level="DEBUG"
) as cm:
mock_session = mock_backend.ResultSession.return_value # pylint: disable=no-member
mock_session.query.return_value.filter.return_value.all.return_value = [
mock.MagicMock(**{"to_dict.return_value": {"status": "SUCCESS", "task_id": "123"}})
]

fetcher = BulkStateFetcher()
result = fetcher.get_many(
[
mock.MagicMock(task_id="123"),
mock.MagicMock(task_id="456"),
]
)
fetcher = BulkStateFetcher()
result = fetcher.get_many(
[
mock.MagicMock(task_id="123"),
mock.MagicMock(task_id="456"),
]
)

assert result == {'123': ('SUCCESS', None), '456': ("PENDING", None)}
assert [
'DEBUG:airflow.executors.celery_executor.BulkStateFetcher:Fetched 2 state(s) for 2 task(s)'
] == cm.output

@pytest.mark.integration("redis")
@pytest.mark.integration("rabbitmq")
Expand All @@ -459,7 +469,9 @@ def test_should_support_base_backend(self):
with _prepare_app():
mock_backend = mock.MagicMock(autospec=BaseBackend)

with mock.patch.object(celery_executor.app, 'backend', mock_backend):
with mock.patch.object(celery_executor.app, 'backend', mock_backend), self.assertLogs(
"airflow.executors.celery_executor.BulkStateFetcher", level="DEBUG"
) as cm:
fetcher = BulkStateFetcher(1)
result = fetcher.get_many(
[
Expand All @@ -469,3 +481,6 @@ def test_should_support_base_backend(self):
)

assert result == {'123': ('SUCCESS', None), '456': ("PENDING", None)}
assert [
'DEBUG:airflow.executors.celery_executor.BulkStateFetcher:Fetched 2 state(s) for 2 task(s)'
] == cm.output