From 020d8c2c2352c2e491bf8f7f7788431e729a5fa4 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Tue, 10 Jun 2025 16:20:37 +0530 Subject: [PATCH 1/8] Improve xcom_pull to reflect reality for mapped tasks --- task-sdk/src/airflow/sdk/bases/xcom.py | 64 ++++++++++++++++++- .../airflow/sdk/execution_time/task_runner.py | 28 +++++--- 2 files changed, 82 insertions(+), 10 deletions(-) diff --git a/task-sdk/src/airflow/sdk/bases/xcom.py b/task-sdk/src/airflow/sdk/bases/xcom.py index 0c330652956f2..fb68cb8274401 100644 --- a/task-sdk/src/airflow/sdk/bases/xcom.py +++ b/task-sdk/src/airflow/sdk/bases/xcom.py @@ -21,7 +21,14 @@ import structlog -from airflow.sdk.execution_time.comms import DeleteXCom, GetXCom, SetXCom, XComResult +from airflow.sdk.execution_time.comms import ( + DeleteXCom, + GetXCom, + GetXComSequenceSlice, + SetXCom, + XComResult, + XComSequenceSliceResult, +) log = structlog.get_logger(logger_name="task") @@ -274,6 +281,61 @@ def get_one( ) return None + @classmethod + def get_all( + cls, + *, + key: str, + dag_id: str, + task_id: str, + run_id: str, + ) -> Any | None: + """ + Retrieve all XCom values for a task, typically from all map indexes. + + This method returns "full" XCom values (i.e. uses ``deserialize_value`` + from the XCom backend). + + If there are no results, *None* is returned. If XCom entries exist, + a list containing all matching XCom values is returned. + + This is particularly useful for getting all XCom values from all map + indexes of a mapped task at once. + + :param key: A key for the XCom. Only XComs with this key will be returned. + :param run_id: DAG run ID for the task. + :param dag_id: DAG ID to pull XComs from. + :param task_id: Task ID to pull XComs from. + :return: List of all XCom values if found, None if no XComs exist. 
+ """ + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + # Since Triggers can hit this code path via `sync_to_async` (which uses threads internally) + # we need to make sure that we "atomically" send a request and get the response to that + # back so that two triggers don't end up interleaving requests and create a possible + # race condition where the wrong trigger reads the response. + with SUPERVISOR_COMMS.lock: + SUPERVISOR_COMMS.send_request( + log=log, + msg=GetXComSequenceSlice( + key=key, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + start=None, + stop=None, + step=None, + ), + ) + msg = SUPERVISOR_COMMS.get_message() + + if not isinstance(msg, XComSequenceSliceResult): + raise TypeError(f"Expected XComSequenceSliceResult, received: {type(msg)} {msg}") + + if msg.root is not None: + return msg.root + return None + @staticmethod def serialize_value( value: Any, diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index bbc7394a87a0d..db0a169f52de7 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -338,7 +338,7 @@ def xcom_pull( run_id = self.run_id single_task_requested = isinstance(task_ids, (str, type(None))) - single_map_index_requested = isinstance(map_indexes, (int, type(None), ArgNotSet)) + single_map_index_requested = isinstance(map_indexes, (int, type(None))) if task_ids is None: # default to the current task if not provided @@ -346,11 +346,25 @@ def xcom_pull( elif isinstance(task_ids, str): task_ids = [task_ids] - map_indexes_iterable: Iterable[int | None] = [] - # If map_indexes is not provided, default to use the map_index of the calling task + # If map_indexes is not specified, pull xcoms for all map indexes for per task if isinstance(map_indexes, ArgNotSet): - map_indexes_iterable = [self.map_index] - elif isinstance(map_indexes, int) or map_indexes is None: + 
xcoms = [] + for t_id in task_ids: + values = XCom.get_all( + run_id=run_id, + key=key, + task_id=t_id, + dag_id=dag_id, + ) + if values is None: + xcoms.append(default) + else: + xcoms.extend(values) + return xcoms + + # Original logic when map_indexes is explicitly specified + map_indexes_iterable: Iterable[int | None] = [] + if isinstance(map_indexes, int) or map_indexes is None: map_indexes_iterable = [map_indexes] elif isinstance(map_indexes, Iterable): map_indexes_iterable = map_indexes @@ -360,10 +374,6 @@ def xcom_pull( ) xcoms = [] - # TODO: AIP 72 Execution API only allows working with a single map_index at a time - # this is inefficient and leads to task_id * map_index requests to the API. - # And we can't achieve the original behavior of XCom pull with multiple tasks - # directly now. for t_id, m_idx in product(task_ids, map_indexes_iterable): value = XCom.get_one( run_id=run_id, From 8b72c73fc121dce040f06f28cf1083e7c8461060 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Tue, 10 Jun 2025 16:27:08 +0530 Subject: [PATCH 2/8] adding tests --- .../execution_time/test_task_runner.py | 41 ++++++++++++++----- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index a994e995bace1..ae4738d68e2e1 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -1421,7 +1421,7 @@ def execute(self, context): pytest.param("task_a", 0, {"a": 1, "b": 2}, id="task_id is str, map_index is int"), pytest.param("task_a", [0], [{"a": 1, "b": 2}], id="task_id is str, map_index is list"), pytest.param("task_a", None, {"a": 1, "b": 2}, id="task_id is str, map_index is None"), - pytest.param("task_a", NOTSET, {"a": 1, "b": 2}, id="task_id is str, map_index is ArgNotSet"), + pytest.param("task_a", NOTSET, [{"a": 1, "b": 2}], id="task_id is str, map_index is ArgNotSet"), 
pytest.param(["task_a"], 0, [{"a": 1, "b": 2}], id="task_id is list, map_index is int"), pytest.param(["task_a"], [0], [{"a": 1, "b": 2}], id="task_id is list, map_index is list"), pytest.param(["task_a"], None, [{"a": 1, "b": 2}], id="task_id is list, map_index is None"), @@ -1431,7 +1431,13 @@ def execute(self, context): pytest.param(None, 0, {"a": 1, "b": 2}, id="task_id is None, map_index is int"), pytest.param(None, [0], [{"a": 1, "b": 2}], id="task_id is None, map_index is list"), pytest.param(None, None, {"a": 1, "b": 2}, id="task_id is None, map_index is None"), - pytest.param(None, NOTSET, {"a": 1, "b": 2}, id="task_id is None, map_index is ArgNotSet"), + pytest.param(None, NOTSET, [{"a": 1, "b": 2}], id="task_id is None, map_index is ArgNotSet"), + pytest.param( + ["task_a", "task_b"], + NOTSET, + [{"a": 1, "b": 2}, {"c": 3, "d": 4}], + id="multiple task_ids, map_index is ArgNotSet", + ), ], ) def test_xcom_pull_return_values( @@ -1444,7 +1450,7 @@ def test_xcom_pull_return_values( ): """ Tests return value of xcom_pull under various combinations of task_ids and map_indexes. - The above test covers the expected calls to supervisor comms. + Also verifies the correct XCom method (get_one vs get_all) is called. 
""" class CustomOperator(BaseOperator): @@ -1455,13 +1461,28 @@ def execute(self, context): task = CustomOperator(task_id=test_task_id) runtime_ti = create_runtime_ti(task=task) - value = {"a": 1, "b": 2} - # API server returns serialised value for xcom result, staging it in that way - xcom_value = BaseXCom.serialize_value(value) - mock_supervisor_comms.get_message.return_value = XComResult(key="key", value=xcom_value) - - returned_xcom = runtime_ti.xcom_pull(key="key", task_ids=task_ids, map_indexes=map_indexes) - assert returned_xcom == expected_value + with patch.object(XCom, "get_one") as mock_get_one, patch.object(XCom, "get_all") as mock_get_all: + if map_indexes == NOTSET: + # Use side_effect to return different values for different tasks + def mock_get_all_side_effect(*, task_id, **kwargs): + if task_id == "task_b": + return [{"c": 3, "d": 4}] + return [{"a": 1, "b": 2}] + + mock_get_all.side_effect = mock_get_all_side_effect + mock_get_one.return_value = None + else: + mock_get_one.return_value = {"a": 1, "b": 2} + mock_get_all.return_value = None + + xcom = runtime_ti.xcom_pull(key="key", task_ids=task_ids, map_indexes=map_indexes) + assert xcom == expected_value + if map_indexes == NOTSET: + assert mock_get_all.called + assert not mock_get_one.called + else: + assert mock_get_one.called + assert not mock_get_all.called def test_get_param_from_context( self, mocked_parse, make_ti_context, mock_supervisor_comms, create_runtime_ti From ac70e5cc3ce5976cff4cca47f7ea670caa606145 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Tue, 10 Jun 2025 16:27:41 +0530 Subject: [PATCH 3/8] adding tests --- task-sdk/tests/task_sdk/execution_time/test_task_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index ae4738d68e2e1..c89f976066b50 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ 
b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -1464,7 +1464,7 @@ def execute(self, context): with patch.object(XCom, "get_one") as mock_get_one, patch.object(XCom, "get_all") as mock_get_all: if map_indexes == NOTSET: # Use side_effect to return different values for different tasks - def mock_get_all_side_effect(*, task_id, **kwargs): + def mock_get_all_side_effect(task_id, **kwargs): if task_id == "task_b": return [{"c": 3, "d": 4}] return [{"a": 1, "b": 2}] From 9252bc8997d0f96b4cecdc2ab269c43adf2f1a40 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 11 Jun 2025 13:04:59 +0530 Subject: [PATCH 4/8] fixing one case --- task-sdk/src/airflow/sdk/execution_time/task_runner.py | 6 +++++- task-sdk/tests/task_sdk/execution_time/test_task_runner.py | 4 ++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index db0a169f52de7..c847d9d5b1b66 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -346,7 +346,7 @@ def xcom_pull( elif isinstance(task_ids, str): task_ids = [task_ids] - # If map_indexes is not specified, pull xcoms for all map indexes for per task + # If map_indexes is not specified, pull xcoms from all map indexes for each task if isinstance(map_indexes, ArgNotSet): xcoms = [] for t_id in task_ids: @@ -360,6 +360,10 @@ def xcom_pull( xcoms.append(default) else: xcoms.extend(values) + + # For single task pulling from unmapped task, return single value + if single_task_requested and len(xcoms) == 1: + return xcoms[0] return xcoms # Original logic when map_indexes is explicitly specified diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index c89f976066b50..62efd3181c33f 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ 
b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -1421,7 +1421,6 @@ def execute(self, context): pytest.param("task_a", 0, {"a": 1, "b": 2}, id="task_id is str, map_index is int"), pytest.param("task_a", [0], [{"a": 1, "b": 2}], id="task_id is str, map_index is list"), pytest.param("task_a", None, {"a": 1, "b": 2}, id="task_id is str, map_index is None"), - pytest.param("task_a", NOTSET, [{"a": 1, "b": 2}], id="task_id is str, map_index is ArgNotSet"), pytest.param(["task_a"], 0, [{"a": 1, "b": 2}], id="task_id is list, map_index is int"), pytest.param(["task_a"], [0], [{"a": 1, "b": 2}], id="task_id is list, map_index is list"), pytest.param(["task_a"], None, [{"a": 1, "b": 2}], id="task_id is list, map_index is None"), @@ -1431,13 +1430,14 @@ def execute(self, context): pytest.param(None, 0, {"a": 1, "b": 2}, id="task_id is None, map_index is int"), pytest.param(None, [0], [{"a": 1, "b": 2}], id="task_id is None, map_index is list"), pytest.param(None, None, {"a": 1, "b": 2}, id="task_id is None, map_index is None"), - pytest.param(None, NOTSET, [{"a": 1, "b": 2}], id="task_id is None, map_index is ArgNotSet"), pytest.param( ["task_a", "task_b"], NOTSET, [{"a": 1, "b": 2}, {"c": 3, "d": 4}], id="multiple task_ids, map_index is ArgNotSet", ), + pytest.param("task_a", NOTSET, {"a": 1, "b": 2}, id="task_id is str, map_index is ArgNotSet"), + pytest.param(None, NOTSET, {"a": 1, "b": 2}, id="task_id is None, map_index is ArgNotSet"), ], ) def test_xcom_pull_return_values( From 4be8c12a5f6bb61127b3cf4e7b2736d4d7483b7e Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 11 Jun 2025 13:15:13 +0530 Subject: [PATCH 5/8] fixing broken tests --- task-sdk/tests/task_sdk/execution_time/test_task_runner.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 62efd3181c33f..2f7e197c4c7c3 100644 --- 
a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -1931,13 +1931,11 @@ def execute(self, context): runtime_ti = create_runtime_ti(task=task) run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) - mock_xcom_backend.get_one.assert_called_once_with( + mock_xcom_backend.get_all.assert_called_once_with( key="key", dag_id="test_dag", task_id="pull_task", run_id="test_run", - map_index=-1, - include_prior_dates=False, ) assert not any( From a2c01b1176f41e29dbd5238e93b14a231216441a Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 11 Jun 2025 13:43:39 +0530 Subject: [PATCH 6/8] review comments from TP --- task-sdk/src/airflow/sdk/bases/xcom.py | 15 +++++---------- .../src/airflow/sdk/execution_time/task_runner.py | 5 +---- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/task-sdk/src/airflow/sdk/bases/xcom.py b/task-sdk/src/airflow/sdk/bases/xcom.py index fb68cb8274401..c9b777daca32e 100644 --- a/task-sdk/src/airflow/sdk/bases/xcom.py +++ b/task-sdk/src/airflow/sdk/bases/xcom.py @@ -289,15 +289,12 @@ def get_all( dag_id: str, task_id: str, run_id: str, - ) -> Any | None: + ) -> Any: """ Retrieve all XCom values for a task, typically from all map indexes. - This method returns "full" XCom values (i.e. uses ``deserialize_value`` - from the XCom backend). - - If there are no results, *None* is returned. If XCom entries exist, - a list containing all matching XCom values is returned. + An ``XComSequenceSliceResult`` never contains *None*; if no values were + found, an empty list is returned. This is particularly useful for getting all XCom values from all map indexes of a mapped task at once. @@ -306,7 +303,7 @@ def get_all( :param run_id: DAG run ID for the task. :param dag_id: DAG ID to pull XComs from. :param task_id: Task ID to pull XComs from. - :return: List of all XCom values if found, None if no XComs exist. 
+ :return: List of all XCom values if found. """ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS @@ -332,9 +329,7 @@ def get_all( if not isinstance(msg, XComSequenceSliceResult): raise TypeError(f"Expected XComSequenceSliceResult, received: {type(msg)} {msg}") - if msg.root is not None: - return msg.root - return None + return msg.root @staticmethod def serialize_value( diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index c847d9d5b1b66..0b72b7a1b2fc0 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -356,10 +356,7 @@ def xcom_pull( task_id=t_id, dag_id=dag_id, ) - if values is None: - xcoms.append(default) - else: - xcoms.extend(values) + xcoms.extend(values) # For single task pulling from unmapped task, return single value if single_task_requested and len(xcoms) == 1: From 2b9dde897804c15623bf127a6dfd9e5db36bc6cd Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 11 Jun 2025 15:42:32 +0530 Subject: [PATCH 7/8] improved code Co-authored-by: Tzu-ping Chung --- task-sdk/src/airflow/sdk/execution_time/task_runner.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 0b72b7a1b2fc0..2374520e63f21 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -348,15 +348,16 @@ def xcom_pull( # If map_indexes is not specified, pull xcoms from all map indexes for each task if isinstance(map_indexes, ArgNotSet): - xcoms = [] - for t_id in task_ids: - values = XCom.get_all( + xcoms = [ + value + for t_id in task_ids + for value in XCom.get_all( run_id=run_id, key=key, task_id=t_id, dag_id=dag_id, ) - xcoms.extend(values) + ] # For single task pulling from unmapped task, return single 
value if single_task_requested and len(xcoms) == 1: From 67a5a20589c83f4326fdb768ed614bd31e41bfab Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 11 Jun 2025 17:10:01 +0530 Subject: [PATCH 8/8] fixing broken test --- .../execution_time/test_task_runner.py | 51 ++++++++++++++----- 1 file changed, 38 insertions(+), 13 deletions(-) diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 2f7e197c4c7c3..d786fdaa5b2bc 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -77,6 +77,7 @@ GetTICount, GetVariable, GetXCom, + GetXComSequenceSlice, OKResponse, PrevSuccessfulDagRunResult, SetRenderedFields, @@ -91,6 +92,7 @@ TriggerDagRun, VariableResult, XComResult, + XComSequenceSliceResult, ) from airflow.sdk.execution_time.context import ( ConnectionAccessor, @@ -1113,7 +1115,7 @@ def test_get_context_with_ti_context_from_server(self, create_runtime_ti, mock_s task = BaseOperator(task_id="hello") # Assume the context is sent from the API server - # `task_sdk/tests/api/test_client.py::test_task_instance_start` checks the context is received + # `task-sdk/tests/api/test_client.py::test_task_instance_start` checks the context is received # from the API server runtime_ti = create_runtime_ti(task=task, dag_id="basic_task") @@ -1387,7 +1389,17 @@ def execute(self, context): runtime_ti = create_runtime_ti(task=task, **extra_for_ti) ser_value = BaseXCom.serialize_value(xcom_values) - mock_supervisor_comms.get_message.return_value = XComResult(key="key", value=ser_value) + + def mock_get_message_side_effect(*args, **kwargs): + calls = mock_supervisor_comms.send_request.call_args_list + if calls: + last_call = calls[-1] + msg = last_call[1]["msg"] + if isinstance(msg, GetXComSequenceSlice): + return XComSequenceSliceResult(root=[ser_value]) + return XComResult(key="key", value=ser_value) + + 
mock_supervisor_comms.get_message.side_effect = mock_get_message_side_effect run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) @@ -1403,17 +1415,30 @@ def execute(self, context): task_id = test_task_id for map_index in map_indexes: if map_index == NOTSET: - map_index = -1 - mock_supervisor_comms.send_request.assert_any_call( - log=mock.ANY, - msg=GetXCom( - key="key", - dag_id="test_dag", - run_id="test_run", - task_id=task_id, - map_index=map_index, - ), - ) + mock_supervisor_comms.send_request.assert_any_call( + log=mock.ANY, + msg=GetXComSequenceSlice( + key="key", + dag_id="test_dag", + run_id="test_run", + task_id=task_id, + start=None, + stop=None, + step=None, + ), + ) + else: + expected_map_index = map_index if map_index is not None else None + mock_supervisor_comms.send_request.assert_any_call( + log=mock.ANY, + msg=GetXCom( + key="key", + dag_id="test_dag", + run_id="test_run", + task_id=task_id, + map_index=expected_map_index, + ), + ) @pytest.mark.parametrize( "task_ids, map_indexes, expected_value",